# Copyright 2019-2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A rule for encoding a text format protocol buffer into binary.
Example usage:
proto_library(
name = "calculator_proto",
srcs = ["calculator.proto"],
)
encode_binary_proto(
name = "foo_binary",
deps = [":calculator_proto"],
message_type = "mediapipe.CalculatorGraphConfig",
input = "foo.pbtxt",
)
Args:
name: The name of this target.
deps: A list of proto_library targets that define messages that may be
used in the input file.
input: The text format protocol buffer.
message_type: The root message of the buffer.
output: The desired name of the output file. Optional.
"""
PROTOC = "@com_google_protobuf//:protoc"

def _canonicalize_proto_path_oss(all_protos, genfile_path):
    """For protos from an external repository, canonicalize the proto path and the file name.

    Returns:
      Proto path list and proto source file list.
    """
    proto_paths = []
    proto_file_names = []
    for s in all_protos.to_list():
        if s.path.startswith(genfile_path):
            repo_name, _, file_name = s.path[len(genfile_path + "/external/"):].partition("/")

            # Handle virtual imports.
            if file_name.startswith("_virtual_imports"):
                repo_name = repo_name + "/" + "/".join(file_name.split("/", 2)[:2])
                file_name = file_name.split("/", 2)[-1]
            proto_paths.append(genfile_path + "/external/" + repo_name)
            proto_file_names.append(file_name)
        else:
            proto_file_names.append(s.path)

    return ([" --proto_path=" + path for path in proto_paths], proto_file_names)

def _get_proto_provider(dep):
    """Get the provider for protocol buffers from a dependency.

    Necessary because Bazel does not provide the .proto provider but ProtoInfo
    cannot be created from Starlark at the moment.

    Returns:
      The provider containing information about protocol buffers.
    """
    if ProtoInfo in dep:
        return dep[ProtoInfo]
    elif hasattr(dep, "proto"):
        return dep.proto
    else:
        fail("cannot happen, rule definition requires .proto or ProtoInfo")

def _encode_binary_proto_impl(ctx):
    """Implementation of the encode_binary_proto rule."""
    all_protos = depset(
        direct = [],
        transitive = [_get_proto_provider(dep).transitive_sources for dep in ctx.attr.deps],
    )
    textpb = ctx.file.input
    binarypb = ctx.outputs.output or ctx.actions.declare_file(
        textpb.basename.rsplit(".", 1)[0] + ".binarypb",
        sibling = textpb,
    )

    path_list, file_list = _canonicalize_proto_path_oss(all_protos, ctx.genfiles_dir.path)

    # Note: the combination of absolute_paths and proto_path, as well as the exact
    # order of gendir before ., is needed for the proto compiler to resolve
    # import statements that reference proto files produced by a genrule.
    ctx.actions.run_shell(
        tools = all_protos.to_list() + [textpb, ctx.executable._proto_compiler],
        outputs = [binarypb],
        command = " ".join(
            [
                ctx.executable._proto_compiler.path,
                "--encode=" + ctx.attr.message_type,
                "--proto_path=" + ctx.genfiles_dir.path,
                "--proto_path=" + ctx.bin_dir.path,
                "--proto_path=.",
            ] + path_list + file_list +
            ["<", textpb.path, ">", binarypb.path],
        ),
        mnemonic = "EncodeProto",
    )
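
    # Rough shape of the assembled shell command (paths are illustrative, not literal):
    #   protoc --encode=<message_type> \
    #       --proto_path=<gendir> --proto_path=<bindir> --proto_path=. \
    #       [--proto_path=<gendir>/external/<repo> ...] [<proto files> ...] \
    #       < foo.pbtxt > foo.binarypb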
    output_depset = depset([binarypb])
    return [DefaultInfo(
        files = output_depset,
        data_runfiles = ctx.runfiles(transitive_files = output_depset),
    )]

_encode_binary_proto = rule(
    implementation = _encode_binary_proto_impl,
    attrs = {
        "_proto_compiler": attr.label(
            executable = True,
            default = Label(PROTOC),
            cfg = "exec",
        ),
        "deps": attr.label_list(
            providers = [[ProtoInfo], ["proto"]],
        ),
        "input": attr.label(
            mandatory = True,
            allow_single_file = True,
        ),
        "message_type": attr.string(
            mandatory = True,
        ),
        "output": attr.output(),
    },
)

def encode_binary_proto(name, input, message_type, deps, **kwargs):
    if type(input) == type("string"):
        input_label = input
        textproto_srcs = [input]
    elif type(input) == type(dict()):
        # We cannot accept a select, as macros are unable to manipulate selects.
        input_label = select(input)
        srcs_dict = dict()
        for k, v in input.items():
            srcs_dict[k] = [v]
        textproto_srcs = select(srcs_dict)
    else:
        fail("input should be a string or a dict, got %s" % input)

    _encode_binary_proto(
        name = name,
        input = input_label,
        message_type = message_type,
        deps = deps,
        **kwargs
    )

def _generate_proto_descriptor_set_impl(ctx):
    """Implementation of the generate_proto_descriptor_set rule."""
    all_protos = depset(transitive = [
        _get_proto_provider(dep).transitive_sources
        for dep in ctx.attr.deps
        if ProtoInfo in dep or hasattr(dep, "proto")
    ])
    descriptor = ctx.outputs.output

    # Note: the combination of absolute_paths and proto_path, as well as the exact
    # order of gendir before ., is needed for the proto compiler to resolve
    # import statements that reference proto files produced by a genrule.
    ctx.actions.run(
        inputs = all_protos,
        tools = [ctx.executable._proto_compiler],
        outputs = [descriptor],
        executable = ctx.executable._proto_compiler,
        arguments = [
            "--descriptor_set_out=%s" % descriptor.path,
            "--proto_path=" + ctx.genfiles_dir.path,
            "--proto_path=" + ctx.bin_dir.path,
            "--proto_path=.",
        ] + [s.path for s in all_protos.to_list()],
        mnemonic = "GenerateProtoDescriptor",
    )

generate_proto_descriptor_set = rule(
    implementation = _generate_proto_descriptor_set_impl,
    attrs = {
        "_proto_compiler": attr.label(
            executable = True,
            default = Label(PROTOC),
            cfg = "exec",
        ),
        "deps": attr.label_list(
            providers = [[ProtoInfo], ["proto"]],
        ),
    },
    outputs = {"output": "%{name}.proto.bin"},
)
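
# Minimal usage sketch for generate_proto_descriptor_set (target and dep names
# are assumptions, not part of this file):
#
#   generate_proto_descriptor_set(
#       name = "calculator_descriptor",
#       deps = [":calculator_proto"],
#   )
#
# This runs protoc with --descriptor_set_out and writes the transitive
# FileDescriptorSet to calculator_descriptor.proto.bin.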