mediapipe/mediapipe/framework/tool/proto_util_lite.cc
MediaPipe Team d68f5e4169 Project import generated by Copybara.
PiperOrigin-RevId: 253489161
2019-06-16 16:06:57 -07:00

394 lines
15 KiB
C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/tool/proto_util_lite.h"
#include <tuple>
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/type_map.h"
namespace mediapipe {
namespace tool {
using proto_ns::io::ArrayInputStream;
using proto_ns::io::CodedInputStream;
using proto_ns::io::CodedOutputStream;
using proto_ns::io::StringOutputStream;
using WireFormatLite = ProtoUtilLite::WireFormatLite;
using FieldAccess = ProtoUtilLite::FieldAccess;
using FieldValue = ProtoUtilLite::FieldValue;
using ProtoPath = ProtoUtilLite::ProtoPath;
using FieldType = ProtoUtilLite::FieldType;
// Returns true if a wire type includes a length indicator.
bool IsLengthDelimited(WireFormatLite::WireType wire_type) {
return wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED;
}
// Reads a single data value for a wire type.
::mediapipe::Status ReadFieldValue(uint32 tag, CodedInputStream* in,
std::string* result) {
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
if (IsLengthDelimited(wire_type)) {
uint32 length;
RET_CHECK(in->ReadVarint32(&length));
RET_CHECK(in->ReadString(result, length));
} else {
std::string field_data;
StringOutputStream sos(&field_data);
CodedOutputStream cos(&sos);
RET_CHECK(WireFormatLite::SkipField(in, tag, &cos));
// Skip the tag written by SkipField.
int tag_size = CodedOutputStream::VarintSize32(tag);
cos.Trim();
result->assign(field_data, tag_size, std::string::npos);
}
return ::mediapipe::OkStatus();
}
// Reads the packed sequence of data values for a wire type.
::mediapipe::Status ReadPackedValues(WireFormatLite::WireType wire_type,
CodedInputStream* in,
std::vector<std::string>* field_values) {
uint32 data_size;
RET_CHECK(in->ReadVarint32(&data_size));
// fake_tag encodes the wire-type for calls to WireFormatLite::SkipField.
uint32 fake_tag = WireFormatLite::MakeTag(1, wire_type);
while (data_size > 0) {
std::string number;
RETURN_IF_ERROR(ReadFieldValue(fake_tag, in, &number));
RET_CHECK_LE(number.size(), data_size);
field_values->push_back(number);
data_size -= number.size();
}
return ::mediapipe::OkStatus();
}
// Extracts the data value(s) for one field from a serialized message.
// The message with these field values removed is written to |out|.
::mediapipe::Status GetFieldValues(uint32 field_id,
WireFormatLite::WireType wire_type,
CodedInputStream* in, CodedOutputStream* out,
std::vector<std::string>* field_values) {
uint32 tag;
while ((tag = in->ReadTag()) != 0) {
int field_number = WireFormatLite::GetTagFieldNumber(tag);
if (field_number == field_id) {
if (!IsLengthDelimited(wire_type) &&
IsLengthDelimited(WireFormatLite::GetTagWireType(tag))) {
RETURN_IF_ERROR(ReadPackedValues(wire_type, in, field_values));
} else {
std::string value;
RETURN_IF_ERROR(ReadFieldValue(tag, in, &value));
field_values->push_back(value);
}
} else {
RET_CHECK(WireFormatLite::SkipField(in, tag, out));
}
}
return ::mediapipe::OkStatus();
}
// Injects the data value(s) for one field into a serialized message.
void SetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type,
const std::vector<std::string>& field_values,
CodedOutputStream* out) {
uint32 tag = WireFormatLite::MakeTag(field_id, wire_type);
for (const std::string& field_value : field_values) {
out->WriteVarint32(tag);
if (IsLengthDelimited(wire_type)) {
out->WriteVarint32(field_value.length());
}
out->WriteRaw(field_value.data(), field_value.length());
}
}
FieldAccess::FieldAccess(uint32 field_id, FieldType field_type)
: field_id_(field_id), field_type_(field_type) {}
::mediapipe::Status FieldAccess::SetMessage(const std::string& message) {
ArrayInputStream ais(message.data(), message.size());
CodedInputStream in(&ais);
StringOutputStream sos(&message_);
CodedOutputStream out(&sos);
WireFormatLite::WireType wire_type =
WireFormatLite::WireTypeForFieldType(field_type_);
return GetFieldValues(field_id_, wire_type, &in, &out, &field_values_);
}
void FieldAccess::GetMessage(std::string* result) {
*result = message_;
StringOutputStream sos(result);
CodedOutputStream out(&sos);
WireFormatLite::WireType wire_type =
WireFormatLite::WireTypeForFieldType(field_type_);
SetFieldValues(field_id_, wire_type, field_values_, &out);
}
std::vector<FieldValue>* FieldAccess::mutable_field_values() {
return &field_values_;
}
// Replaces a range of field values for one field nested within a protobuf.
::mediapipe::Status ProtoUtilLite::ReplaceFieldRange(
FieldValue* message, ProtoPath proto_path, int length, FieldType field_type,
const std::vector<FieldValue>& field_values) {
int field_id, index;
std::tie(field_id, index) = proto_path.front();
proto_path.erase(proto_path.begin());
FieldAccess access(field_id, !proto_path.empty()
? WireFormatLite::TYPE_MESSAGE
: field_type);
RETURN_IF_ERROR(access.SetMessage(*message));
std::vector<std::string>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK(index >= 0 && index < v.size());
RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length, field_type,
field_values));
} else {
RET_CHECK(index >= 0 && index <= v.size());
RET_CHECK(index + length >= 0 && index + length <= v.size());
v.erase(v.begin() + index, v.begin() + index + length);
v.insert(v.begin() + index, field_values.begin(), field_values.end());
}
message->clear();
access.GetMessage(message);
return ::mediapipe::OkStatus();
}
// Returns a range of field values from one field nested within a protobuf.
::mediapipe::Status ProtoUtilLite::GetFieldRange(
const FieldValue& message, ProtoPath proto_path, int length,
FieldType field_type, std::vector<FieldValue>* field_values) {
int field_id, index;
std::tie(field_id, index) = proto_path.front();
proto_path.erase(proto_path.begin());
FieldAccess access(field_id, !proto_path.empty()
? WireFormatLite::TYPE_MESSAGE
: field_type);
RETURN_IF_ERROR(access.SetMessage(message));
std::vector<std::string>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK(index >= 0 && index < v.size());
RETURN_IF_ERROR(
GetFieldRange(v[index], proto_path, length, field_type, field_values));
} else {
RET_CHECK(index >= 0 && index <= v.size());
RET_CHECK(index + length >= 0 && index + length <= v.size());
field_values->insert(field_values->begin(), v.begin() + index,
v.begin() + index + length);
}
return ::mediapipe::OkStatus();
}
// If ok, returns OkStatus, otherwise returns InvalidArgumentError.
template <typename T>
::mediapipe::Status SyntaxStatus(bool ok, const std::string& text, T* result) {
return ok ? ::mediapipe::OkStatus()
: ::mediapipe::InvalidArgumentError(absl::StrCat(
"Syntax error: \"", text, "\"",
" for type: ", MediaPipeTypeStringOrDemangled<T>(), "."));
}
// Templated parsing of a std::string value.
template <typename T>
::mediapipe::Status ParseValue(const std::string& text, T* result) {
return SyntaxStatus(absl::SimpleAtoi(text, result), text, result);
}
template <>
::mediapipe::Status ParseValue<double>(const std::string& text,
double* result) {
return SyntaxStatus(absl::SimpleAtod(text, result), text, result);
}
template <>
::mediapipe::Status ParseValue<float>(const std::string& text, float* result) {
return SyntaxStatus(absl::SimpleAtof(text, result), text, result);
}
template <>
::mediapipe::Status ParseValue<bool>(const std::string& text, bool* result) {
return SyntaxStatus(absl::SimpleAtob(text, result), text, result);
}
template <>
::mediapipe::Status ParseValue<std::string>(const std::string& text,
std::string* result) {
*result = text;
return ::mediapipe::OkStatus();
}
// Templated formatting of a primitive value.
template <typename T>
std::string FormatValue(T v) {
return FieldValue(absl::StrCat(v));
}
// A helper function to parse and serialize one primtive value.
template <typename T>
::mediapipe::Status WritePrimitive(
void (*writer)(T, proto_ns::io::CodedOutputStream*),
const std::string& text, CodedOutputStream* out) {
T value;
RETURN_IF_ERROR(ParseValue<T>(text, &value));
(*writer)(value, out);
return ::mediapipe::OkStatus();
}
// Serializes a protobuf FieldValue.
static ::mediapipe::Status SerializeValue(const std::string& text,
FieldType field_type,
FieldValue* field_value) {
::mediapipe::Status status;
StringOutputStream sos(field_value);
CodedOutputStream out(&sos);
using W = WireFormatLite;
switch (field_type) {
case W::TYPE_DOUBLE:
return WritePrimitive(W::WriteDoubleNoTag, text, &out);
case W::TYPE_FLOAT:
return WritePrimitive(W::WriteFloatNoTag, text, &out);
case W::TYPE_INT64:
return WritePrimitive(W::WriteInt64NoTag, text, &out);
case W::TYPE_UINT64:
return WritePrimitive(W::WriteUInt64NoTag, text, &out);
case W::TYPE_INT32:
return WritePrimitive(W::WriteInt32NoTag, text, &out);
case W::TYPE_FIXED64:
return WritePrimitive(W::WriteFixed64NoTag, text, &out);
case W::TYPE_FIXED32:
return WritePrimitive(W::WriteFixed32NoTag, text, &out);
case W::TYPE_BOOL: {
return WritePrimitive(W::WriteBoolNoTag, text, &out);
}
case W::TYPE_BYTES:
case W::TYPE_STRING: {
out.WriteRaw(text.data(), text.size());
return ::mediapipe::OkStatus();
}
case W::TYPE_GROUP:
case W::TYPE_MESSAGE:
return ::mediapipe::UnimplementedError(
"SerializeValue cannot serialize a Message.");
case W::TYPE_UINT32:
return WritePrimitive(W::WriteUInt32NoTag, text, &out);
case W::TYPE_ENUM:
return WritePrimitive(W::WriteEnumNoTag, text, &out);
case W::TYPE_SFIXED32:
return WritePrimitive(W::WriteSFixed32NoTag, text, &out);
case W::TYPE_SFIXED64:
return WritePrimitive(W::WriteSFixed64NoTag, text, &out);
case W::TYPE_SINT32:
return WritePrimitive(W::WriteSInt32NoTag, text, &out);
case W::TYPE_SINT64:
return WritePrimitive(W::WriteSInt64NoTag, text, &out);
}
return ::mediapipe::UnimplementedError("SerializeValue unimplemented type.");
}
// A helper function for deserializing one text value.
template <typename CType, FieldType DeclaredType>
static ::mediapipe::Status ReadPrimitive(CodedInputStream* input,
std::string* result) {
CType value;
if (!WireFormatLite::ReadPrimitive<CType, DeclaredType>(input, &value)) {
return ::mediapipe::InvalidArgumentError(absl::StrCat(
"Bad serialized value: ", MediaPipeTypeStringOrDemangled<CType>(),
"."));
}
*result = FormatValue(value);
return ::mediapipe::OkStatus();
}
// Deserializes a protobuf FieldValue.
static ::mediapipe::Status DeserializeValue(const FieldValue& bytes,
FieldType field_type,
std::string* result) {
ArrayInputStream ais(bytes.data(), bytes.size());
CodedInputStream input(&ais);
typedef WireFormatLite W;
switch (field_type) {
case W::TYPE_DOUBLE:
return ReadPrimitive<double, W::TYPE_DOUBLE>(&input, result);
case W::TYPE_FLOAT:
return ReadPrimitive<float, W::TYPE_FLOAT>(&input, result);
case W::TYPE_INT64:
return ReadPrimitive<proto_int64, W::TYPE_INT64>(&input, result);
case W::TYPE_UINT64:
return ReadPrimitive<proto_uint64, W::TYPE_UINT64>(&input, result);
case W::TYPE_INT32:
return ReadPrimitive<int32, W::TYPE_INT32>(&input, result);
case W::TYPE_FIXED64:
return ReadPrimitive<proto_uint64, W::TYPE_FIXED64>(&input, result);
case W::TYPE_FIXED32:
return ReadPrimitive<uint32, W::TYPE_FIXED32>(&input, result);
case W::TYPE_BOOL:
return ReadPrimitive<bool, W::TYPE_BOOL>(&input, result);
case W::TYPE_BYTES:
case W::TYPE_STRING: {
*result = bytes;
return ::mediapipe::OkStatus();
}
case W::TYPE_GROUP:
case W::TYPE_MESSAGE:
CHECK(false) << "DeserializeValue cannot deserialize a Message.";
case W::TYPE_UINT32:
return ReadPrimitive<uint32, W::TYPE_UINT32>(&input, result);
case W::TYPE_ENUM:
return ReadPrimitive<int, W::TYPE_ENUM>(&input, result);
case W::TYPE_SFIXED32:
return ReadPrimitive<int32, W::TYPE_SFIXED32>(&input, result);
case W::TYPE_SFIXED64:
return ReadPrimitive<proto_int64, W::TYPE_SFIXED64>(&input, result);
case W::TYPE_SINT32:
return ReadPrimitive<int32, W::TYPE_SINT32>(&input, result);
case W::TYPE_SINT64:
return ReadPrimitive<proto_int64, W::TYPE_SINT64>(&input, result);
}
return ::mediapipe::UnimplementedError(
"DeserializeValue unimplemented type.");
}
::mediapipe::Status ProtoUtilLite::Serialize(
const std::vector<std::string>& text_values, FieldType field_type,
std::vector<FieldValue>* result) {
result->clear();
result->reserve(text_values.size());
for (const std::string& text_value : text_values) {
FieldValue field_value;
RETURN_IF_ERROR(SerializeValue(text_value, field_type, &field_value));
result->push_back(field_value);
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status ProtoUtilLite::Deserialize(
const std::vector<FieldValue>& field_values, FieldType field_type,
std::vector<std::string>* result) {
result->clear();
result->reserve(field_values.size());
for (const FieldValue& field_value : field_values) {
std::string text_value;
RETURN_IF_ERROR(DeserializeValue(field_value, field_type, &text_value));
result->push_back(text_value);
}
return ::mediapipe::OkStatus();
}
} // namespace tool
} // namespace mediapipe