Internal Change
PiperOrigin-RevId: 547265380
This commit is contained in:
parent
e4ec4d2526
commit
4788fddde9
1
mediapipe/tasks/cc/text/utils/xnn_utils/BUILD
Normal file
1
mediapipe/tasks/cc/text/utils/xnn_utils/BUILD
Normal file
|
@ -0,0 +1 @@
|
|||
# Utilities needed to interacte with XNNPACK.
|
887
mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.cc
Normal file
887
mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.cc
Normal file
|
@ -0,0 +1,887 @@
|
|||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/time/clock.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/source_location.h"
|
||||
#include "file/base/helpers.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
|
||||
#include "third_party/XNNPACK/include/xnnpack.h"
|
||||
#include "util/gtl/stl_logging.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace xnn_utils {
|
||||
namespace {
|
||||
|
||||
// XNNPACK supports broadcasting, this function inferences the output shape
|
||||
// based on input tensor shapes.
|
||||
std::vector<size_t> OutDimsForElementwiseOp(const Tensor& lhs,
|
||||
const Tensor& rhs) {
|
||||
DCHECK(!lhs.dims.empty());
|
||||
DCHECK(!rhs.dims.empty());
|
||||
std::vector<size_t> lhs_dims_rev(lhs.dims.rbegin(), lhs.dims.rend());
|
||||
std::vector<size_t> rhs_dims_rev(rhs.dims.rbegin(), rhs.dims.rend());
|
||||
DCHECK([&]() -> bool {
|
||||
for (size_t i = 0; i < std::min(lhs_dims_rev.size(), rhs_dims_rev.size());
|
||||
++i) {
|
||||
if ((lhs_dims_rev[i] != rhs_dims_rev[i]) && (lhs_dims_rev[i] != 1) &&
|
||||
(rhs_dims_rev[i] != 1)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}()) << "lhs "
|
||||
<< lhs.dims << " rhs " << rhs.dims;
|
||||
std::vector<size_t> out_dims(
|
||||
std::max(lhs_dims_rev.size(), rhs_dims_rev.size()));
|
||||
for (int i = 0; i < out_dims.size(); ++i) {
|
||||
if (lhs_dims_rev.size() <= i) {
|
||||
out_dims[i] = rhs_dims_rev[i];
|
||||
} else if (rhs_dims_rev.size() <= i) {
|
||||
out_dims[i] = lhs_dims_rev[i];
|
||||
} else {
|
||||
out_dims[i] = lhs_dims_rev[i] == 1 ? rhs_dims_rev[i] : lhs_dims_rev[i];
|
||||
}
|
||||
}
|
||||
return std::vector<size_t>(out_dims.rbegin(), out_dims.rend());
|
||||
}
|
||||
|
||||
// If out_id is invalid, we need to allocate tensor for intermediate result.
|
||||
// Otherwise, set out_id in out_metadata.
|
||||
absl::Status MaybeAllocateIntermediateTensor(xnn_subgraph_t subgraph,
|
||||
uint32_t out_id,
|
||||
Tensor& out_metadata) {
|
||||
RET_CHECK_GT(out_metadata.dims.size(), 0);
|
||||
if (out_id == XNN_INVALID_VALUE_ID) {
|
||||
// The output is intermediate, thus allocate tensor.
|
||||
MP_RETURN_IF_ERROR(out_metadata.DefineAsIntermediateTensor(*subgraph));
|
||||
} else {
|
||||
out_metadata.tensor_id = out_id;
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status MaybeAllocateIntermediateTensor(xnn_subgraph_t subgraph,
|
||||
Tensor& out_metadata) {
|
||||
return MaybeAllocateIntermediateTensor(subgraph, out_metadata.tensor_id,
|
||||
out_metadata);
|
||||
}
|
||||
|
||||
absl::Status AllocateIntermediateTensor(xnn_subgraph_t subgraph,
|
||||
Tensor& out_metadata) {
|
||||
return MaybeAllocateIntermediateTensor(subgraph, XNN_INVALID_VALUE_ID,
|
||||
out_metadata);
|
||||
}
|
||||
|
||||
// 1.0/jax.nn.softplus(0.0) = 1.442695041
|
||||
// scale = softplus(w) * 1.442695041 / np.sqrt(query.shape[-1])
|
||||
void SoftPlus(size_t cnt, const std::vector<size_t>& query_dims, float* weight,
|
||||
float* scale) {
|
||||
constexpr double r_softplus_0 = 1.442695041;
|
||||
// softplus(x) = np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)
|
||||
// scale = softplus(per_dim_scale) / (sqrt(input.dims[-1]) * softplus(0))
|
||||
const double r_softplus_0_over_sqrt_d =
|
||||
r_softplus_0 / std::sqrt(query_dims.back());
|
||||
for (int i = 0; i < cnt; ++i) {
|
||||
scale[i] = log1p(exp(-abs(weight[i]))) + fmax(weight[i], 0.0f);
|
||||
scale[i] *= r_softplus_0_over_sqrt_d;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<std::unique_ptr<XnnGraph>> XnnGraphBuilder::Build(
|
||||
std::unique_ptr<RuntimeConfigs> runtime_configs) {
|
||||
if (!runtime_configs) {
|
||||
runtime_configs = std::make_unique<RuntimeConfigs>();
|
||||
runtime_configs->xnn_num_threads = 1;
|
||||
runtime_configs->xnn_profile = false;
|
||||
}
|
||||
VLOG(2) << "XnnGraphBuilder::Build() building...";
|
||||
auto build_begin = absl::Now();
|
||||
RET_CHECK_EQ(xnn_status_success, xnn_initialize(nullptr));
|
||||
|
||||
absl::flat_hash_set<std::shared_ptr<Tensor>> output_tensors;
|
||||
{
|
||||
uint32_t cnt = input_tensors_.size();
|
||||
for (auto& t : interm_tensors_) {
|
||||
if (t->is_output_tensor) {
|
||||
RET_CHECK_EQ(t->tensor_id, XNN_INVALID_VALUE_ID);
|
||||
t->tensor_id = cnt++;
|
||||
output_tensors.insert(t);
|
||||
}
|
||||
}
|
||||
for (auto& t : output_tensors) {
|
||||
interm_tensors_.erase(t);
|
||||
}
|
||||
for (auto& t : rope_weigths_) {
|
||||
interm_tensors_.erase(t);
|
||||
t->tensor_id = cnt++;
|
||||
}
|
||||
}
|
||||
|
||||
xnn_subgraph_t subgraph_ptr = nullptr;
|
||||
RET_CHECK_EQ(xnn_status_success,
|
||||
xnn_create_subgraph(
|
||||
/*external_value_ids=*/input_tensors_.size() +
|
||||
output_tensors.size() + rope_weigths_.size(),
|
||||
/*flags=*/0, &subgraph_ptr));
|
||||
RET_CHECK_NE(subgraph_ptr, nullptr);
|
||||
|
||||
XnnSubgraphPtr subgraph{subgraph_ptr, xnn_delete_subgraph};
|
||||
|
||||
for (auto& input : input_tensors_) {
|
||||
MP_RETURN_IF_ERROR(input->DefineAsInput(*subgraph));
|
||||
}
|
||||
for (auto& output : output_tensors) {
|
||||
MP_RETURN_IF_ERROR(output->DefineAsOutput(*subgraph));
|
||||
}
|
||||
{
|
||||
for (auto& t : rope_weigths_) {
|
||||
MP_RETURN_IF_ERROR(t->DefineRope(*subgraph));
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& [loc, step] : build_steps_) {
|
||||
if (auto s = step(subgraph.get()); !s.ok()) {
|
||||
s.AddSourceLocation(loc);
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
XnnGraph result(std::move(subgraph), std::move(runtime_configs));
|
||||
result.input_tensors_ = std::move(input_tensors_);
|
||||
result.output_tensors_ = std::move(output_tensors);
|
||||
result.interm_tensors_ = std::move(interm_tensors_);
|
||||
|
||||
VLOG(2) << "XnnGraphBuilder::Build() creating runtime...";
|
||||
auto create_begin = absl::Now();
|
||||
MP_RETURN_IF_ERROR(result.CreateRuntime());
|
||||
VLOG(2) << "XnnGraphBuilder::Build() setting up runtime...";
|
||||
auto setup_begin = absl::Now();
|
||||
MP_RETURN_IF_ERROR(result.SetupRuntime());
|
||||
|
||||
auto end = absl::Now();
|
||||
VLOG(2) << "XnnGraphBuilder::Build() done build, Total " << end - build_begin
|
||||
<< ", create runtime " << setup_begin - create_begin
|
||||
<< ", setup runtime " << end - setup_begin;
|
||||
return std::make_unique<XnnGraph>(std::move(result));
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::NewInput(
|
||||
Tensor::DimsType dims, absl::SourceLocation loc) {
|
||||
auto t = std::make_shared<Tensor>(std::move(dims), data_type_);
|
||||
t->AllocateBufferIfNeeded();
|
||||
t->tensor_id = input_tensors_.size();
|
||||
input_tensors_.insert(t);
|
||||
return t;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::NewWeight(
|
||||
absl::string_view file_path, Tensor::DimsType dims,
|
||||
absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto t, NewWeight(std::move(dims)));
|
||||
MP_RETURN_IF_ERROR(t->LoadFromFile(file_path));
|
||||
return t;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::NewWeight(
|
||||
Tensor::DimsType dims, absl::SourceLocation loc) {
|
||||
auto t = std::make_shared<Tensor>(std::move(dims), data_type_);
|
||||
NewWeight(t, loc);
|
||||
return t;
|
||||
}
|
||||
|
||||
void XnnGraphBuilder::NewWeight(std::shared_ptr<Tensor> t,
|
||||
absl::SourceLocation loc) {
|
||||
build_steps_.push_back(
|
||||
{loc, [this, t](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
if (interm_tensors_.contains(t)) {
|
||||
MP_RETURN_IF_ERROR(t->DefineWeight(*subgraph));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
|
||||
interm_tensors_.insert(t);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::IntermediateTensor(
|
||||
Tensor::DimsType dims, absl::SourceLocation loc) {
|
||||
auto t = std::make_shared<Tensor>(std::move(dims), data_type_);
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc, [this, t](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
// Could be moved to output tensors, thus need check.
|
||||
if (interm_tensors_.contains(t)) {
|
||||
return AllocateIntermediateTensor(subgraph, *t);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
|
||||
interm_tensors_.insert(t);
|
||||
return t;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Reshape(
|
||||
std::shared_ptr<Tensor> input, Tensor::DimsType new_dims,
|
||||
absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(new_dims)));
|
||||
RET_CHECK_EQ(input->num_elements, output->num_elements)
|
||||
<< "otherwise reshape does not make sense.";
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc, [this, input, output](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
|
||||
subgraph, output->tensor_id, *output));
|
||||
|
||||
RET_CHECK_EQ(xnn_status_success,
|
||||
xnn_define_static_reshape(
|
||||
subgraph, output->dims.size(), output->dims.data(),
|
||||
input->tensor_id, output->tensor_id, /*flags=*/0));
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::FullConn(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
|
||||
std::shared_ptr<Tensor> bias, FullConnParams params,
|
||||
absl::SourceLocation loc) {
|
||||
const auto& input_dim = input->dims;
|
||||
const auto& weight_dim = weight->dims;
|
||||
DCHECK_GT(input_dim.size(), 1);
|
||||
DCHECK_GE(weight_dim.size(), 2);
|
||||
if (weight_dim.size() == 3) {
|
||||
RET_CHECK_EQ(weight_dim[0], 1);
|
||||
} else if (weight_dim.size() == 4) {
|
||||
RET_CHECK_EQ(weight_dim[0], 1);
|
||||
RET_CHECK_EQ(weight_dim[1], 1);
|
||||
}
|
||||
if (bias) {
|
||||
RET_CHECK_LE(bias->dims.size(), 1);
|
||||
}
|
||||
|
||||
Tensor::DimsType out_dims = input_dim;
|
||||
// Not considering reshape 2D
|
||||
if (params.transpose) {
|
||||
RET_CHECK_EQ(weight_dim.size(), 2) << "otherwise change following line";
|
||||
RET_CHECK_EQ(input_dim.back(), *(weight_dim.end() - 2));
|
||||
out_dims.back() = weight_dim.back();
|
||||
} else {
|
||||
RET_CHECK_EQ(input_dim.back(), weight_dim.back());
|
||||
out_dims.pop_back();
|
||||
for (size_t i = 0; i < weight_dim.size() - 1; ++i) {
|
||||
// NHD . BTD -> NHBT
|
||||
out_dims.push_back(weight_dim[i]);
|
||||
}
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(out_dims)));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc,
|
||||
[this, input, weight, bias, params,
|
||||
output](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
|
||||
subgraph, output->tensor_id, *output));
|
||||
|
||||
RET_CHECK_EQ(
|
||||
xnn_status_success,
|
||||
xnn_define_fully_connected(
|
||||
subgraph, params.out_min, params.out_max, input->tensor_id,
|
||||
weight->tensor_id,
|
||||
bias ? bias->tensor_id : XNN_INVALID_VALUE_ID,
|
||||
output->tensor_id,
|
||||
/*flags=*/params.transpose ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0));
|
||||
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Permute(
|
||||
std::shared_ptr<Tensor> input, Tensor::DimsType permute,
|
||||
absl::SourceLocation loc) {
|
||||
RET_CHECK_EQ(input->dims.size(), permute.size());
|
||||
const auto& old_dims = input->dims;
|
||||
std::vector<size_t> new_dims;
|
||||
for (size_t i = 0; i < permute.size(); ++i) {
|
||||
new_dims.push_back(old_dims[permute[i]]);
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(new_dims)));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc,
|
||||
[this, permute, input, output](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
|
||||
|
||||
RET_CHECK_EQ(xnn_status_success,
|
||||
xnn_define_static_transpose(
|
||||
subgraph, permute.size(), permute.data(),
|
||||
input->tensor_id, output->tensor_id, /*flags=*/0));
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Square(
|
||||
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
|
||||
subgraph, output->tensor_id, *output));
|
||||
RET_CHECK_EQ(
|
||||
xnn_status_success,
|
||||
xnn_define_square(subgraph, input->tensor_id, output->tensor_id,
|
||||
/*flags=*/0));
|
||||
return absl::Status();
|
||||
}});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Softmax(
|
||||
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
|
||||
subgraph, output->tensor_id, *output));
|
||||
RET_CHECK_EQ(
|
||||
xnn_status_success,
|
||||
xnn_define_softmax(subgraph, input->tensor_id, output->tensor_id,
|
||||
/*flags=*/0));
|
||||
return absl::Status();
|
||||
}});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::SquareRoot(
|
||||
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
|
||||
subgraph, output->tensor_id, *output));
|
||||
RET_CHECK_EQ(xnn_status_success,
|
||||
xnn_define_square_root(subgraph, input->tensor_id,
|
||||
output->tensor_id,
|
||||
/*flags=*/0));
|
||||
return absl::Status();
|
||||
}});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::AvgLastDim(
|
||||
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto before_reshape,
|
||||
IntermediateTensor(Tensor::DimsType{input->dims.begin(),
|
||||
input->dims.end() - 1}));
|
||||
build_steps_.push_back(
|
||||
{loc,
|
||||
[this, input, before_reshape](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
|
||||
subgraph, before_reshape->tensor_id, *before_reshape));
|
||||
size_t reduction_axis = input->dims.size() - 1;
|
||||
RET_CHECK_EQ(
|
||||
xnn_status_success,
|
||||
xnn_define_static_mean(subgraph, 1, &reduction_axis,
|
||||
input->tensor_id, before_reshape->tensor_id,
|
||||
/*flags=*/0));
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
|
||||
Tensor::DimsType new_dims = input->dims;
|
||||
new_dims.back() = 1;
|
||||
return Reshape(before_reshape, std::move(new_dims));
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Rms(
|
||||
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto sqr_out, Square(input, loc));
|
||||
|
||||
ASSIGN_OR_RETURN(auto mean_out, AvgLastDim(sqr_out, loc));
|
||||
|
||||
return SquareRoot(mean_out, loc);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::RmsNorm(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> scale,
|
||||
absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto rms_out, Rms(input));
|
||||
|
||||
ASSIGN_OR_RETURN(auto clamped_rms, Clamp(rms_out, {.out_min = 1e-6}));
|
||||
|
||||
// div_out = input / rms
|
||||
ASSIGN_OR_RETURN(auto div_out, ElementDiv(input, clamped_rms));
|
||||
|
||||
// div_out * (1 + scale) = div_out + div_out * scale
|
||||
ASSIGN_OR_RETURN(auto normed_div_out, ElementMul(div_out, scale));
|
||||
|
||||
return ElementAdd(div_out, normed_div_out);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementAdd(
|
||||
std::shared_ptr<Tensor> lhs, float rhs, ClampParams params,
|
||||
absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1}));
|
||||
MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector<float>({rhs})));
|
||||
|
||||
return ElementAdd(lhs, rhs_tensor, params, loc);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementAdd(
|
||||
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
|
||||
ClampParams params, absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto output,
|
||||
IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs)));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc,
|
||||
[this, lhs, rhs, output,
|
||||
params](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
|
||||
RET_CHECK_EQ(xnn_status_success,
|
||||
xnn_define_add2(subgraph, params.out_min, params.out_max,
|
||||
lhs->tensor_id, rhs->tensor_id,
|
||||
output->tensor_id, /*flags=*/0));
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementMul(
|
||||
std::shared_ptr<Tensor> lhs, float rhs, ClampParams params,
|
||||
absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1}));
|
||||
MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector<float>({rhs})));
|
||||
|
||||
return ElementMul(lhs, rhs_tensor, params, loc);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementMul(
|
||||
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
|
||||
ClampParams params, absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto output,
|
||||
IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs)));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc,
|
||||
[this, lhs, rhs, output,
|
||||
params](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
|
||||
RET_CHECK_EQ(
|
||||
xnn_status_success,
|
||||
xnn_define_multiply2(subgraph, params.out_min, params.out_max,
|
||||
lhs->tensor_id, rhs->tensor_id,
|
||||
output->tensor_id, /*flags=*/0));
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementDiv(
|
||||
std::shared_ptr<Tensor> lhs, float rhs, ClampParams params,
|
||||
absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1}));
|
||||
MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector<float>({rhs})));
|
||||
|
||||
return ElementDiv(lhs, rhs_tensor, params, loc);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementDiv(
|
||||
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
|
||||
ClampParams params, absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto output,
|
||||
IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs)));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc,
|
||||
[this, lhs, rhs, output,
|
||||
params](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
|
||||
RET_CHECK_EQ(
|
||||
xnn_status_success,
|
||||
xnn_define_divide(subgraph, params.out_min, params.out_max,
|
||||
lhs->tensor_id, rhs->tensor_id,
|
||||
output->tensor_id, /*flags=*/0));
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
// TODO: write an op?
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::PerDimScale(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> per_dim_scale,
|
||||
absl::SourceLocation loc) {
|
||||
// input: B T N H
|
||||
// 1/softplus(0) = 1.442695041
|
||||
// scale = softplus(w) * 1.442695041 / np.sqrt(query.shape[-1])
|
||||
// query = query * scale
|
||||
const auto& input_dim = input->dims;
|
||||
DCHECK_GE(input_dim.size(), 1);
|
||||
const size_t H = input_dim.back();
|
||||
|
||||
if (!per_dim_scale_cache_.contains(H) ||
|
||||
!per_dim_scale_cache_[H].contains(per_dim_scale.get())) {
|
||||
ASSIGN_OR_RETURN(auto cached_pds, NewWeight(per_dim_scale->dims));
|
||||
|
||||
auto* pds_in = static_cast<float*>(per_dim_scale->Data());
|
||||
std::vector<float> pds_scaled(per_dim_scale->num_elements);
|
||||
SoftPlus(per_dim_scale->num_elements, input_dim, pds_in, pds_scaled.data());
|
||||
MP_RETURN_IF_ERROR(cached_pds->LoadFromVec(std::move(pds_scaled)));
|
||||
per_dim_scale_cache_[H][per_dim_scale.get()] = cached_pds;
|
||||
}
|
||||
|
||||
return ElementMul(input, per_dim_scale_cache_[H][per_dim_scale.get()]);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Rope(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> segment_pos,
|
||||
absl::SourceLocation loc) {
|
||||
// TODO: seg_pos should not be weight.
|
||||
rope_weigths_.insert(segment_pos);
|
||||
|
||||
const auto& input_dim = input->dims;
|
||||
const auto& segment_pos_dim = segment_pos->dims;
|
||||
// B T N H
|
||||
RET_CHECK_EQ(input_dim.size(), 4) << "xnn requirement";
|
||||
// S H
|
||||
RET_CHECK_EQ(segment_pos_dim.size(), 2) << "xnn requirement";
|
||||
|
||||
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input_dim));
|
||||
|
||||
const auto input_seq_size = input_dim[1];
|
||||
RET_CHECK_LE(input_seq_size, segment_pos_dim[0]);
|
||||
const auto head_dim_H = input_dim[3];
|
||||
RET_CHECK_EQ(head_dim_H, segment_pos_dim[1]);
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc,
|
||||
[this, input, output, segment_pos,
|
||||
input_seq_size](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
|
||||
RET_CHECK_EQ(
|
||||
xnn_status_success,
|
||||
xnn_define_rope(subgraph, input_seq_size, input->tensor_id,
|
||||
segment_pos->tensor_id, output->tensor_id,
|
||||
/*flags=*/0));
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::BatchMatMul(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
|
||||
FullConnParams params, absl::SourceLocation loc) {
|
||||
const auto& lhs_dim = input->dims;
|
||||
const auto& rhs_dim = weight->dims;
|
||||
|
||||
// [B, N, T, H] . [B, N, S, H], N == 12, B == 1
|
||||
DCHECK_EQ(lhs_dim.size(), 4);
|
||||
DCHECK_EQ(rhs_dim.size(), 4);
|
||||
DCHECK_EQ(lhs_dim.back(), rhs_dim.back());
|
||||
DCHECK_EQ(lhs_dim.back(), rhs_dim.back());
|
||||
constexpr size_t num_slices = 12;
|
||||
DCHECK_EQ(lhs_dim[1], num_slices);
|
||||
DCHECK_EQ(rhs_dim[1], num_slices);
|
||||
const size_t S = rhs_dim[2];
|
||||
const size_t T = lhs_dim[2];
|
||||
const size_t batch_size = lhs_dim[0] * lhs_dim[1];
|
||||
DCHECK_EQ(batch_size, rhs_dim[0] * rhs_dim[1]);
|
||||
DCHECK_EQ(batch_size, 12);
|
||||
|
||||
ASSIGN_OR_RETURN(auto output, IntermediateTensor({1, 12, T, S}));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc, [input, output, weight](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
|
||||
|
||||
RET_CHECK_EQ(xnn_status_success,
|
||||
xnn_define_batch_matrix_multiply(
|
||||
subgraph, input->tensor_id, weight->tensor_id,
|
||||
output->tensor_id, /*flags=*/0));
|
||||
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Tanh(
|
||||
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc, [this, input, output](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
|
||||
|
||||
RET_CHECK_EQ(xnn_status_success,
|
||||
xnn_define_tanh(subgraph, input->tensor_id,
|
||||
output->tensor_id, /*flags=*/0));
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::CapTanh(
|
||||
std::shared_ptr<Tensor> input, float cap, absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto div, ElementDiv(input, cap));
|
||||
ASSIGN_OR_RETURN(auto tanh, Tanh(div));
|
||||
return ElementMul(tanh, cap);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::DotAttention(
|
||||
std::shared_ptr<Tensor> query_proj, std::shared_ptr<Tensor> key_proj,
|
||||
std::shared_ptr<Tensor> value_proj, std::shared_ptr<Tensor> atten_mask,
|
||||
std::shared_ptr<Tensor> per_dim_scale, absl::SourceLocation loc) {
|
||||
// BTNH
|
||||
ASSIGN_OR_RETURN(auto query_after_scale,
|
||||
PerDimScale(query_proj, per_dim_scale));
|
||||
|
||||
// Dot similarity
|
||||
// BTNH -> BNTH
|
||||
ASSIGN_OR_RETURN(auto query_permuted,
|
||||
Permute(query_after_scale, {0, 2, 1, 3}));
|
||||
// BSNH -> BNSH
|
||||
ASSIGN_OR_RETURN(auto key_permuted, Permute(key_proj, {0, 2, 1, 3}));
|
||||
// einsum(BNTH.BNSH -> BNTS)
|
||||
ASSIGN_OR_RETURN(auto logits, BatchMatMul(query_permuted, key_permuted));
|
||||
|
||||
// Cap, mask
|
||||
ASSIGN_OR_RETURN(auto cap_logits, CapTanh(logits, 50));
|
||||
ASSIGN_OR_RETURN(auto padded_logits, ElementAdd(atten_mask, cap_logits));
|
||||
ASSIGN_OR_RETURN(auto probs, Softmax(padded_logits));
|
||||
ASSIGN_OR_RETURN(auto value_permuted, Permute(value_proj, {0, 2, 3, 1}));
|
||||
|
||||
// Outcome
|
||||
// BNTS.BNHS -> BNTH
|
||||
ASSIGN_OR_RETURN(auto outcome_before_permute,
|
||||
BatchMatMul(probs, value_permuted));
|
||||
// [B, N, T, H] -> BTNH
|
||||
return Permute(outcome_before_permute, {0, 2, 1, 3});
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::SelfAttentionProj(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
|
||||
absl::SourceLocation loc) {
|
||||
const auto& input_dim = input->dims;
|
||||
const auto& weight_dim = weight->dims;
|
||||
size_t N = 0, H = 0;
|
||||
RET_CHECK_EQ(input_dim.size(), 3) << "BTD";
|
||||
|
||||
std::optional<int> reshaped_N =
|
||||
weight->GetMetadata(kKeySelfAttentionReshapedWeight);
|
||||
RET_CHECK(reshaped_N && *reshaped_N)
|
||||
<< "We rely on " << kKeySelfAttentionReshapedWeight << " to get N";
|
||||
RET_CHECK_EQ(weight_dim.size(), 2) << "NH,D";
|
||||
N = *reshaped_N;
|
||||
H = weight_dim[0] / N;
|
||||
|
||||
// out: B,T,NH
|
||||
ASSIGN_OR_RETURN(auto proj, MatMul(input, weight));
|
||||
|
||||
// B,T,NH -> B,T,N,H
|
||||
return Reshape(proj, {input_dim[0], input_dim[1], N, H});
|
||||
}
|
||||
|
||||
absl::Status XnnGraph::CreateRuntime() {
|
||||
RET_CHECK_EQ(runtime_.get(), nullptr);
|
||||
xnn_runtime_t runtime_ptr = nullptr;
|
||||
uint32_t flags = 0;
|
||||
if (runtime_configs_->xnn_profile) {
|
||||
flags |= XNN_FLAG_BASIC_PROFILING;
|
||||
|
||||
if (!runtime_configs_->xnn_profile_csv.empty()) {
|
||||
MP_RETURN_IF_ERROR(file::SetContents(runtime_configs_->xnn_profile_csv,
|
||||
"node_id; time(us); op_name\n",
|
||||
file::Defaults()));
|
||||
}
|
||||
}
|
||||
pthreadpool_t threadpool =
|
||||
pthreadpool_create(runtime_configs_->xnn_num_threads);
|
||||
threadpool_ = XnnThreadpoolPtr{threadpool, pthreadpool_destroy};
|
||||
|
||||
RET_CHECK_EQ(xnn_status_success,
|
||||
xnn_create_runtime_v2(owned_subgraph_.get(), threadpool, flags,
|
||||
&runtime_ptr));
|
||||
RET_CHECK_NE(runtime_ptr, nullptr);
|
||||
runtime_ = XnnRuntimePtr{runtime_ptr, xnn_delete_runtime};
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status XnnGraph::SetupRuntime() {
|
||||
{
|
||||
VLOG(3) << "input size " << input_tensors_.size();
|
||||
VLOG(3) << "output size " << output_tensors_.size();
|
||||
VLOG(3) << "rope size " << rope_weigths_.size();
|
||||
externals_.clear();
|
||||
// Init external
|
||||
for (const auto& input : input_tensors_) {
|
||||
VLOG(3) << "input id " << input->tensor_id;
|
||||
externals_.push_back(xnn_external_value{input->tensor_id, input->Data()});
|
||||
}
|
||||
for (const auto& output : output_tensors_) {
|
||||
VLOG(3) << "output id " << output->tensor_id;
|
||||
externals_.push_back(
|
||||
xnn_external_value{output->tensor_id, output->Data()});
|
||||
}
|
||||
for (const auto& t : rope_weigths_) {
|
||||
VLOG(3) << "rope id " << t->tensor_id;
|
||||
}
|
||||
}
|
||||
RET_CHECK_EQ(
|
||||
xnn_status_success,
|
||||
xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status XnnGraph::Run() {
|
||||
RET_CHECK(runtime_);
|
||||
|
||||
RET_CHECK_EQ(xnn_status_success, xnn_invoke_runtime(runtime_.get()));
|
||||
|
||||
if (runtime_configs_->xnn_profile) {
|
||||
size_t required_size = 0;
|
||||
|
||||
// xnn_get_runtime_profiling_info is called twice. The first time it sets
|
||||
// required_size to the required size of the buffer to store the result and
|
||||
// returns xnn_status_out_of_memory. The second time it writes the result to
|
||||
// the buffer provided that the buffer is large enough and returns
|
||||
// xnn_status_success.
|
||||
xnn_status status = xnn_get_runtime_profiling_info(
|
||||
runtime_.get(), xnn_profile_info_operator_name, /*param_value_size*/ 0,
|
||||
/*param_value*/ nullptr, &required_size);
|
||||
std::vector<char> operator_names;
|
||||
if (status == xnn_status_out_of_memory) {
|
||||
operator_names.resize(required_size);
|
||||
status = xnn_get_runtime_profiling_info(
|
||||
runtime_.get(), xnn_profile_info_operator_name, operator_names.size(),
|
||||
operator_names.data(), &required_size);
|
||||
}
|
||||
RET_CHECK_EQ(status, xnn_status_success);
|
||||
size_t num_operators;
|
||||
status = xnn_get_runtime_profiling_info(
|
||||
runtime_.get(), xnn_profile_info_num_operators, sizeof(num_operators),
|
||||
&num_operators, &required_size);
|
||||
RET_CHECK_EQ(status, xnn_status_success);
|
||||
status = xnn_get_runtime_profiling_info(
|
||||
runtime_.get(), xnn_profile_info_operator_timing,
|
||||
/*param_value_size*/ 0,
|
||||
/*param_value*/ nullptr, &required_size);
|
||||
std::vector<uint64_t> operator_timings;
|
||||
if (status == xnn_status_out_of_memory) {
|
||||
operator_timings.resize(required_size / sizeof(uint64_t));
|
||||
status = xnn_get_runtime_profiling_info(
|
||||
runtime_.get(), xnn_profile_info_operator_timing,
|
||||
operator_timings.size() * sizeof(uint64_t), operator_timings.data(),
|
||||
&required_size);
|
||||
}
|
||||
RET_CHECK_EQ(status, xnn_status_success);
|
||||
const char* operator_name = nullptr;
|
||||
size_t name_len = 0;
|
||||
std::stringstream ss;
|
||||
for (size_t node_index = 0; node_index < num_operators; ++node_index) {
|
||||
operator_name = &operator_names[name_len];
|
||||
name_len += strlen(operator_name) + 1;
|
||||
VLOG(2) << "XnnGraphBuilder::Profile() node_index: " << node_index
|
||||
<< ", time: " << operator_timings[node_index] << " us, "
|
||||
<< operator_name << "\n";
|
||||
if (!runtime_configs_->xnn_profile_csv.empty()) {
|
||||
// Use ';' instead of ',' because operator_name contains comma.
|
||||
ss << node_index << "; " << operator_timings[node_index] << "; "
|
||||
<< operator_name << "\n";
|
||||
}
|
||||
}
|
||||
if (!runtime_configs_->xnn_profile_csv.empty()) {
|
||||
MP_RETURN_IF_ERROR(file::AppendStringToFile(
|
||||
runtime_configs_->xnn_profile_csv, ss.str(), file::Defaults()));
|
||||
}
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Clamp(
|
||||
std::shared_ptr<Tensor> input, ClampParams params,
|
||||
absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
|
||||
|
||||
build_steps_.push_back(
|
||||
{loc,
|
||||
[this, input, output, params](xnn_subgraph_t subgraph) -> absl::Status {
|
||||
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
|
||||
|
||||
RET_CHECK_EQ(xnn_status_success,
|
||||
xnn_define_clamp(subgraph, params.out_min, params.out_max,
|
||||
input->tensor_id, output->tensor_id,
|
||||
/*flags=*/0));
|
||||
return absl::OkStatus();
|
||||
}});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Gelu(
|
||||
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
|
||||
// x^2
|
||||
ASSIGN_OR_RETURN(auto sqr_out, Square(input));
|
||||
|
||||
// 0.044715 * x^2
|
||||
ASSIGN_OR_RETURN(auto sqr_4471, ElementMul(sqr_out, 0.044715));
|
||||
|
||||
// 1 + 0.044715 * x^2
|
||||
ASSIGN_OR_RETURN(auto sqr_4471_1, ElementAdd(sqr_4471, 1.0f));
|
||||
|
||||
// x + 0.044715 * x^3
|
||||
ASSIGN_OR_RETURN(auto x_cube_4471, ElementMul(sqr_4471_1, input));
|
||||
|
||||
constexpr float sqrt_2_over_pi = 0.7978845608;
|
||||
ASSIGN_OR_RETURN(auto sqrt_2_over_pi_x_cube_4471,
|
||||
ElementMul(x_cube_4471, sqrt_2_over_pi));
|
||||
|
||||
// tanh(x + 0.044715 * x^3)
|
||||
ASSIGN_OR_RETURN(auto tanh_x_cube_4471, Tanh(sqrt_2_over_pi_x_cube_4471));
|
||||
|
||||
// 1 + tanh(x + 0.044715 * x^3)
|
||||
ASSIGN_OR_RETURN(auto tanh_x_cube_4471_1, ElementAdd(tanh_x_cube_4471, 1.0f));
|
||||
|
||||
// 0.5 * (1 + [tanh(x + 0.044715 * x^3)])
|
||||
ASSIGN_OR_RETURN(auto cdf, ElementMul(tanh_x_cube_4471_1, 0.5));
|
||||
|
||||
return ElementMul(input, cdf);
|
||||
}
|
||||
|
||||
} // namespace xnn_utils
|
||||
} // namespace mediapipe
|
288
mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h
Normal file
288
mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h
Normal file
|
@ -0,0 +1,288 @@
|
|||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_
|
||||
|
||||
#include <sys/types.h>
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <initializer_list>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/source_location.h"
|
||||
#include "file/base/helpers.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
|
||||
#include "third_party/XNNPACK/include/xnnpack.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace xnn_utils {
|
||||
|
||||
using XnnSubgraphPtr =
|
||||
std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)>;
|
||||
using XnnRuntimePtr =
|
||||
std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)>;
|
||||
using XnnThreadpoolPtr =
|
||||
std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)>;
|
||||
|
||||
struct ClampParams {
|
||||
float out_min = -std::numeric_limits<float>::infinity();
|
||||
float out_max = std::numeric_limits<float>::infinity();
|
||||
};
|
||||
|
||||
struct FullConnParams : public ClampParams {
|
||||
bool transpose = false;
|
||||
};
|
||||
|
||||
struct RuntimeConfigs {
|
||||
bool xnn_profile;
|
||||
std::string xnn_profile_csv;
|
||||
size_t xnn_num_threads;
|
||||
};
|
||||
|
||||
class XnnGraph;
|
||||
|
||||
// XnnGraphBuilder is used to construct XnnGraph (through Build()). Once a
|
||||
// XnnGraph is constructed, it can run for multiple times.
|
||||
class XnnGraphBuilder {
|
||||
public:
|
||||
static constexpr absl::string_view kKeySelfAttentionReshapedWeight{
|
||||
"self_attention_reshaped_weight_N"};
|
||||
|
||||
explicit XnnGraphBuilder(xnn_datatype data_type = xnn_datatype_fp32)
|
||||
: data_type_(data_type) {}
|
||||
virtual ~XnnGraphBuilder() = default;
|
||||
|
||||
absl::StatusOr<std::unique_ptr<XnnGraph>> Build(
|
||||
std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
|
||||
|
||||
// New input or output tensor.
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> NewInput(
|
||||
Tensor::DimsType dims,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
// New static weight, populate value before Build()
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> NewWeight(
|
||||
Tensor::DimsType dims,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> NewWeight(
|
||||
absl::string_view file_path, Tensor::DimsType dims,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
void NewWeight(std::shared_ptr<Tensor> t,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
// Element wise square.
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> Square(
|
||||
std::shared_ptr<Tensor> input,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> SquareRoot(
|
||||
std::shared_ptr<Tensor> input,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> Gelu(
|
||||
std::shared_ptr<Tensor> input,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> Clamp(
|
||||
std::shared_ptr<Tensor> input, ClampParams params,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> Tanh(
|
||||
std::shared_ptr<Tensor> input,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
// logits = cap * jnp.tanh(logits / cap)
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> CapTanh(
|
||||
std::shared_ptr<Tensor> input, float cap,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
// Average over last dimension, keep num of dims same.
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> AvgLastDim(
|
||||
std::shared_ptr<Tensor> input,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> Rms(
|
||||
std::shared_ptr<Tensor> input,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> RmsNorm(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> scale,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> Reshape(
|
||||
std::shared_ptr<Tensor> input, Tensor::DimsType new_dims,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> Permute(
|
||||
std::shared_ptr<Tensor> input, Tensor::DimsType permute,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
// input: [B * I]
|
||||
// filter: [O * I], [I * O] if transpose
|
||||
// return: [B * O]
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> MatMul(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current()) {
|
||||
return MatMul(input, weight, FullConnParams(), loc);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> MatMul(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
|
||||
FullConnParams params,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current()) {
|
||||
return FullConn(input, weight, nullptr, params, loc);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> BatchMatMul(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
|
||||
FullConnParams params = FullConnParams(),
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> FullConn(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
|
||||
std::shared_ptr<Tensor> bias,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current()) {
|
||||
return FullConn(input, weight, bias, FullConnParams(), loc);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> FullConn(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
|
||||
std::shared_ptr<Tensor> bias, FullConnParams params,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> Softmax(
|
||||
std::shared_ptr<Tensor> input,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionProj(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> ElementAdd(
|
||||
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
|
||||
ClampParams params = ClampParams(),
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> ElementAdd(
|
||||
std::shared_ptr<Tensor> lhs, float rhs,
|
||||
ClampParams params = ClampParams(),
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> ElementMul(
|
||||
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
|
||||
ClampParams params = ClampParams(),
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> ElementMul(
|
||||
std::shared_ptr<Tensor> lhs, float rhs,
|
||||
ClampParams params = ClampParams(),
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> ElementDiv(
|
||||
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
|
||||
ClampParams params = ClampParams(),
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> ElementDiv(
|
||||
std::shared_ptr<Tensor> lhs, float rhs,
|
||||
ClampParams params = ClampParams(),
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> Rope(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> segment_pos,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> PerDimScale(
|
||||
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> per_dim_scale,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> DotAttention(
|
||||
std::shared_ptr<Tensor> query_proj, std::shared_ptr<Tensor> key_proj,
|
||||
std::shared_ptr<Tensor> value_proj, std::shared_ptr<Tensor> atten_mask,
|
||||
std::shared_ptr<Tensor> per_dim_scale,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
protected:
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> IntermediateTensor(
|
||||
Tensor::DimsType dims,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
const xnn_datatype data_type_;
|
||||
|
||||
std::vector<std::pair<absl::SourceLocation,
|
||||
std::function<absl::Status(xnn_subgraph_t)>>>
|
||||
build_steps_;
|
||||
|
||||
absl::flat_hash_set<std::shared_ptr<Tensor>> input_tensors_;
|
||||
absl::flat_hash_set<std::shared_ptr<Tensor>> interm_tensors_;
|
||||
|
||||
// TODO: fix this.
|
||||
// This is sort of bug that the weights used for rope has to be defined with
|
||||
// EXTERNAL flag, but with id out of the external range.
|
||||
absl::flat_hash_set<std::shared_ptr<Tensor>> rope_weigths_;
|
||||
|
||||
// Caches
|
||||
absl::flat_hash_map<
|
||||
size_t /*dim*/,
|
||||
absl::flat_hash_map<const Tensor* /*scale*/, std::shared_ptr<Tensor>>>
|
||||
per_dim_scale_cache_;
|
||||
};
|
||||
|
||||
class XnnGraph {
|
||||
public:
|
||||
XnnGraph(XnnSubgraphPtr subgraph,
|
||||
std::unique_ptr<RuntimeConfigs> runtime_configs)
|
||||
: owned_subgraph_(std::move(subgraph)),
|
||||
runtime_configs_(std::move(runtime_configs)) {
|
||||
DCHECK(runtime_configs_);
|
||||
}
|
||||
XnnGraph(XnnGraph&& other) = default;
|
||||
virtual ~XnnGraph() = default;
|
||||
|
||||
// xnn_subgraph should be created with same size.
|
||||
virtual absl::Status Run();
|
||||
|
||||
protected:
|
||||
friend class XnnGraphBuilder;
|
||||
|
||||
absl::Status CreateRuntime();
|
||||
absl::Status SetupRuntime();
|
||||
|
||||
XnnSubgraphPtr owned_subgraph_;
|
||||
|
||||
absl::flat_hash_map<size_t, Tensor> avg_cache_;
|
||||
absl::flat_hash_map<size_t, Tensor> cap_tanh_cache_;
|
||||
|
||||
// Runtime
|
||||
std::unique_ptr<RuntimeConfigs> runtime_configs_;
|
||||
XnnRuntimePtr runtime_{nullptr, xnn_delete_runtime};
|
||||
std::vector<xnn_external_value> externals_;
|
||||
|
||||
XnnThreadpoolPtr threadpool_{nullptr, pthreadpool_destroy};
|
||||
|
||||
absl::flat_hash_set<std::shared_ptr<Tensor>> input_tensors_;
|
||||
absl::flat_hash_set<std::shared_ptr<Tensor>> output_tensors_;
|
||||
// TODO: see above
|
||||
absl::flat_hash_set<std::shared_ptr<Tensor>> rope_weigths_;
|
||||
|
||||
absl::flat_hash_set<std::shared_ptr<Tensor>> interm_tensors_;
|
||||
};
|
||||
|
||||
} // namespace xnn_utils
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_
|
475
mediapipe/tasks/cc/text/utils/xnn_utils/ulm.cc
Normal file
475
mediapipe/tasks/cc/text/utils/xnn_utils/ulm.cc
Normal file
|
@ -0,0 +1,475 @@
|
|||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/text/text_generator/calculators/preprocessor_util.h"
|
||||
#include "mediapipe/tasks/cc/text/text_generator/calculators/sampler_util.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
|
||||
#include "util/gtl/stl_logging.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace xnn_utils {
|
||||
namespace {
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> ApplyFinalProj(
|
||||
std::shared_ptr<Tensor> inter_layer, const UlmWeights& weights,
|
||||
XnnGraphBuilder& builder) {
|
||||
return builder.FullConn(inter_layer, weights.softmax_linear,
|
||||
weights.softmax_bias);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class OneTokenUlm : public Ulm {
|
||||
public:
|
||||
OneTokenUlm(std::unique_ptr<Ulm> full_ulm, XnnGraph&& other)
|
||||
: Ulm(std::move(other)), full_ulm_(std::move(full_ulm)) {}
|
||||
~OneTokenUlm() override = default;
|
||||
|
||||
absl::Status InitInputTokens(const std::vector<int>& input_ids) override {
|
||||
prev_ids_ = input_ids;
|
||||
MP_RETURN_IF_ERROR(full_ulm_->InitInputTokens(input_ids));
|
||||
// prev_id.size - 1 is the output.
|
||||
return full_ulm_->Run();
|
||||
}
|
||||
|
||||
absl::Status GetNextToken(std::vector<int>* output_ids) override {
|
||||
size_t decode_step = prev_ids_.size() - 1;
|
||||
VLOG(2) << "Decode step " << decode_step;
|
||||
|
||||
if (decode_step == ulm_params_.seq_size_T - 1) {
|
||||
return absl::OutOfRangeError(
|
||||
absl::StrCat("Hit max sequence length ", ulm_params_.seq_size_T));
|
||||
}
|
||||
|
||||
transformer_input_->Borrow(
|
||||
full_ulm_->transformer_input_->Slice(1, decode_step));
|
||||
atten_masks_->Borrow(full_ulm_->atten_masks_->Slice(0, decode_step));
|
||||
MP_RETURN_IF_ERROR(segment_pos_->LoadFromBuffer(
|
||||
full_ulm_->segment_pos_->Slice(0, decode_step)->Data()));
|
||||
for (auto& kv_cache : kv_cache_) {
|
||||
DCHECK(kv_cache.k_slice);
|
||||
DCHECK(kv_cache.v_slice);
|
||||
kv_cache.k_slice->Borrow(kv_cache.k_cache->Slice(1, decode_step));
|
||||
kv_cache.v_slice->Borrow(kv_cache.v_cache->Slice(1, decode_step));
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(SetupRuntime());
|
||||
MP_RETURN_IF_ERROR(Run());
|
||||
|
||||
RET_CHECK(logits_output_);
|
||||
DCHECK_EQ(logits_output_->num_elements, ulm_params_.voc_size_V);
|
||||
|
||||
ASSIGN_OR_RETURN(*output_ids,
|
||||
mediapipe::SampleNextToken(
|
||||
logits_output_->DataAs<float>(),
|
||||
/*batch_size=*/1,
|
||||
/*vocab_size=*/ulm_params_.voc_size_V, /*top_k=*/10,
|
||||
/*top_p=*/1, /*temperature=*/-1));
|
||||
RET_CHECK_EQ(output_ids->size(), 1);
|
||||
prev_ids_.push_back(output_ids->at(0));
|
||||
|
||||
return GetTokenEmbedding(
|
||||
*output_ids,
|
||||
pos_embedding_data_->Slice({decode_step + 1, 0})->DataAs<float>(),
|
||||
full_ulm_->transformer_input_->Slice({0, decode_step + 1, 0})
|
||||
->DataAs<float>());
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<Ulm> full_ulm_;
|
||||
};
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> UlmBuilder::SelfAttentionExcludeNorm(
|
||||
std::shared_ptr<Tensor> input, SelfAttentionArgs args,
|
||||
const SelfAttentionWeights& sa_weights, absl::SourceLocation loc) {
|
||||
// [B, 1|T, N, H]
|
||||
ASSIGN_OR_RETURN(auto k_proj, SelfAttentionProj(input, sa_weights.k_weight));
|
||||
ASSIGN_OR_RETURN(auto q_proj, SelfAttentionProj(input, sa_weights.q_weight));
|
||||
ASSIGN_OR_RETURN(auto v_proj, SelfAttentionProj(input, sa_weights.v_weight));
|
||||
|
||||
ASSIGN_OR_RETURN(auto query_proj_after_rope, Rope(q_proj, args.segment_pos));
|
||||
ASSIGN_OR_RETURN(auto key_proj_after_rope, Rope(k_proj, args.segment_pos));
|
||||
|
||||
if (args.cache) {
|
||||
RET_CHECK(args.cache->k_cache);
|
||||
RET_CHECK(args.cache->v_cache);
|
||||
// When cache is provided, there are 2 cases:
|
||||
if (*(input->dims.end() - 2) != 1) {
|
||||
// Building a normal graph, which is used to initialize cache.
|
||||
key_proj_after_rope->Borrow(args.cache->k_cache).MarkOutput();
|
||||
v_proj->Borrow(args.cache->v_cache).MarkOutput();
|
||||
} else {
|
||||
// Building a one-token graph, which consumes initialized cache.
|
||||
key_proj_after_rope->MarkOutput();
|
||||
args.cache->k_slice = key_proj_after_rope;
|
||||
v_proj->MarkOutput();
|
||||
args.cache->v_slice = v_proj;
|
||||
|
||||
ASSIGN_OR_RETURN(key_proj_after_rope,
|
||||
NewInput(args.cache->k_cache->dims));
|
||||
key_proj_after_rope->Borrow(args.cache->k_cache);
|
||||
ASSIGN_OR_RETURN(v_proj, NewInput(args.cache->v_cache->dims));
|
||||
v_proj->Borrow(args.cache->v_cache);
|
||||
}
|
||||
}
|
||||
|
||||
// encoded, [B, 1|T, N, H]
|
||||
ASSIGN_OR_RETURN(
|
||||
auto kqv_merged,
|
||||
DotAttention(query_proj_after_rope, key_proj_after_rope, v_proj,
|
||||
args.atten_mask, sa_weights.per_dim_scale));
|
||||
|
||||
const size_t B = kqv_merged->dims[0];
|
||||
const size_t T_or_1 = kqv_merged->dims[1];
|
||||
const size_t NH = kqv_merged->num_elements / (B * T_or_1);
|
||||
ASSIGN_OR_RETURN(auto outcome_reshaped, Reshape(kqv_merged, {B, T_or_1, NH}));
|
||||
|
||||
return MatMul(outcome_reshaped, sa_weights.post_proj_weight,
|
||||
{.transpose = false});
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>>
|
||||
UlmBuilder::SelfAttentionIncludeResidual(std::shared_ptr<Tensor> input,
|
||||
SelfAttentionArgs args,
|
||||
const SelfAttentionWeights& params,
|
||||
absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto pre_attention, RmsNorm(input, params.pre_norm));
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
auto post_attention,
|
||||
SelfAttentionExcludeNorm(pre_attention, std::move(args), params));
|
||||
|
||||
ASSIGN_OR_RETURN(auto post_norm, RmsNorm(post_attention, params.post_norm));
|
||||
|
||||
return ElementAdd(input, post_norm);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> UlmBuilder::FeedForwardExcludeResidual(
|
||||
std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
|
||||
absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto first_rms_norm, RmsNorm(input, params.pre_norm));
|
||||
|
||||
ASSIGN_OR_RETURN(auto layer_1, FullConn(first_rms_norm, params.layer_1_weight,
|
||||
params.layer_1_bias));
|
||||
|
||||
ASSIGN_OR_RETURN(auto layer_1_gate_before_gelu,
|
||||
FullConn(first_rms_norm, params.layer_1_gate_weight,
|
||||
params.layer_1_gate_bias));
|
||||
ASSIGN_OR_RETURN(auto layer_1_gate, Gelu(layer_1_gate_before_gelu));
|
||||
|
||||
ASSIGN_OR_RETURN(auto layer_1_and_gate, ElementMul(layer_1, layer_1_gate));
|
||||
if (params.opt_padding) {
|
||||
// activations *= 1.0 - paddings
|
||||
ASSIGN_OR_RETURN(auto tmp, ElementMul(params.opt_padding, -1.0f));
|
||||
ASSIGN_OR_RETURN(tmp, ElementMul(layer_1_and_gate, tmp));
|
||||
ASSIGN_OR_RETURN(layer_1_and_gate, ElementAdd(tmp, layer_1_and_gate));
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
auto layer_2,
|
||||
FullConn(layer_1_and_gate, params.layer_2_weight, params.layer_2_bias));
|
||||
if (params.opt_padding) {
|
||||
// activations *= 1.0 - paddings
|
||||
ASSIGN_OR_RETURN(auto tmp, ElementMul(params.opt_padding, -1.0f));
|
||||
ASSIGN_OR_RETURN(tmp, ElementMul(layer_2, tmp));
|
||||
ASSIGN_OR_RETURN(layer_2, ElementAdd(tmp, layer_2));
|
||||
}
|
||||
|
||||
return RmsNorm(layer_2, params.post_norm);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> UlmBuilder::FeedForwardIncludeResidual(
|
||||
std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
|
||||
absl::SourceLocation loc) {
|
||||
ASSIGN_OR_RETURN(auto before_residual,
|
||||
FeedForwardExcludeResidual(input, params));
|
||||
return ElementAdd(before_residual, input);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<Ulm>> Ulm::CreateUlm(
|
||||
absl::string_view weights_folder, const UlmParams& ulm_params,
|
||||
std::unique_ptr<RuntimeConfigs> runtime_configs) {
|
||||
auto weight_loader =
|
||||
std::make_unique<DefaultUlmWeightsLoader>(weights_folder, ulm_params);
|
||||
return CreateUlm(std::move(weight_loader), std::move(runtime_configs));
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<Ulm>> Ulm::CreateOneTokenUlm(
|
||||
std::unique_ptr<UlmWeightsLoader> weight_loader,
|
||||
std::unique_ptr<RuntimeConfigs> runtime_configs) {
|
||||
UlmBuilder builder;
|
||||
// TODO: might be memory waste here, benchmark.
|
||||
weight_loader->SetBuilder(builder);
|
||||
ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights());
|
||||
|
||||
UlmParams ulm_params = weight_loader->ulm_params();
|
||||
ulm_params.enable_kv_cache = true;
|
||||
|
||||
weight_loader->ulm_params().enable_kv_cache = true;
|
||||
weight_loader->ulm_params().final_norm = false;
|
||||
weight_loader->ulm_params().final_project = false;
|
||||
ASSIGN_OR_RETURN(auto full_ulm, CreateUlm(std::move(weight_loader)));
|
||||
|
||||
ASSIGN_OR_RETURN(auto input, builder.NewInput({ulm_params.batch_size_B, 1,
|
||||
ulm_params.model_dim_D}));
|
||||
ASSIGN_OR_RETURN(auto atten_masks,
|
||||
builder.NewInput({1, ulm_params.seq_size_T}));
|
||||
ASSIGN_OR_RETURN(auto segment_pos,
|
||||
builder.NewWeight({1, ulm_params.head_dim_H}));
|
||||
// To allocate buffer before creating runtime.
|
||||
MP_RETURN_IF_ERROR(segment_pos->LoadFromVec({}, /*exact_match=*/false));
|
||||
|
||||
std::vector<KVCache>& kv_cache = full_ulm->kv_cache_;
|
||||
RET_CHECK_EQ(kv_cache.size(), ulm_params.num_transformer_M);
|
||||
|
||||
auto inter_layer = input;
|
||||
for (int i = 0; i < ulm_params.num_transformer_M; ++i) {
|
||||
const auto& sa = weights.sas[i];
|
||||
ASSIGN_OR_RETURN(auto tmp, builder.SelfAttentionIncludeResidual(
|
||||
inter_layer,
|
||||
{.atten_mask = atten_masks,
|
||||
.segment_pos = segment_pos,
|
||||
.cache = &kv_cache[i]},
|
||||
sa));
|
||||
|
||||
auto& ff = weights.ffs[i];
|
||||
// ff.opt_padding = paddings;
|
||||
ASSIGN_OR_RETURN(inter_layer, builder.FeedForwardIncludeResidual(tmp, ff));
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> logits_output, transformer_output, normed_output;
|
||||
|
||||
if (ulm_params.final_norm) {
|
||||
ASSIGN_OR_RETURN(inter_layer,
|
||||
builder.RmsNorm(inter_layer, weights.final_ln_scale));
|
||||
normed_output = inter_layer;
|
||||
normed_output->MarkOutput();
|
||||
}
|
||||
if (ulm_params.final_project) {
|
||||
RET_CHECK(weights.softmax_linear);
|
||||
ASSIGN_OR_RETURN(logits_output,
|
||||
ApplyFinalProj(inter_layer, weights, builder));
|
||||
logits_output->MarkOutput();
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs)));
|
||||
Ulm* full_ulm_p = full_ulm.get();
|
||||
auto result =
|
||||
std::make_unique<OneTokenUlm>(std::move(full_ulm), std::move(*graph));
|
||||
{
|
||||
Tensor::DimsType dims{ulm_params.seq_size_T, ulm_params.model_dim_D};
|
||||
result->pos_embedding_data_ =
|
||||
std::make_shared<Tensor>(std::move(dims), xnn_datatype_fp32);
|
||||
result->pos_embedding_data_->Borrow(full_ulm_p->pos_embedding_data_);
|
||||
}
|
||||
result->transformer_input_ = input;
|
||||
result->transformer_output_ = transformer_output;
|
||||
result->normed_output_ = normed_output;
|
||||
result->logits_output_ = logits_output;
|
||||
result->segment_pos_ = segment_pos;
|
||||
result->atten_masks_ = atten_masks;
|
||||
if (ulm_params.use_padding) {
|
||||
// result->paddings_ = paddings;
|
||||
}
|
||||
result->kv_cache_ = std::move(kv_cache);
|
||||
|
||||
result->weights_ = std::move(weights);
|
||||
result->ulm_params_ = ulm_params;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<Ulm>> Ulm::CreateUlm(
|
||||
std::unique_ptr<UlmWeightsLoader> weight_loader,
|
||||
std::unique_ptr<RuntimeConfigs> runtime_configs) {
|
||||
UlmBuilder builder;
|
||||
weight_loader->SetBuilder(builder);
|
||||
const auto& ulm_params = weight_loader->ulm_params();
|
||||
RET_CHECK_NE(ulm_params.batch_size_B, 0);
|
||||
|
||||
ASSIGN_OR_RETURN(auto input, builder.NewInput({ulm_params.batch_size_B,
|
||||
ulm_params.seq_size_T,
|
||||
ulm_params.model_dim_D}));
|
||||
ASSIGN_OR_RETURN(auto atten_masks, builder.NewInput({ulm_params.seq_size_T,
|
||||
ulm_params.seq_size_T}));
|
||||
VLOG(1) << "atten mask id " << atten_masks->tensor_id;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto segment_pos,
|
||||
builder.NewWeight({ulm_params.seq_size_T, ulm_params.head_dim_H}));
|
||||
MP_RETURN_IF_ERROR(FillXnnRoPEWeights(*segment_pos));
|
||||
VLOG(1) << "segment pos id " << segment_pos->tensor_id;
|
||||
std::shared_ptr<Tensor> paddings;
|
||||
if (ulm_params.use_padding) {
|
||||
ASSIGN_OR_RETURN(paddings, builder.NewInput({ulm_params.batch_size_B,
|
||||
ulm_params.seq_size_T, 1}));
|
||||
VLOG(1) << "paddings id " << paddings->tensor_id;
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights());
|
||||
std::vector<KVCache> kv_cache;
|
||||
|
||||
auto inter_layer = input;
|
||||
for (int i = 0; i < ulm_params.num_transformer_M; ++i) {
|
||||
const auto& sa = weights.sas[i];
|
||||
KVCache* cache = nullptr;
|
||||
if (ulm_params.enable_kv_cache) {
|
||||
auto k_cache = std::make_shared<Tensor>(
|
||||
Tensor::DimsType{ulm_params.batch_size_B, ulm_params.seq_size_T,
|
||||
ulm_params.n_heads_N, ulm_params.head_dim_H});
|
||||
MP_RETURN_IF_ERROR(k_cache->LoadFromVec({}, /*exact_match=*/false));
|
||||
auto v_cache = std::make_shared<Tensor>(
|
||||
Tensor::DimsType{ulm_params.batch_size_B, ulm_params.seq_size_T,
|
||||
ulm_params.n_heads_N, ulm_params.head_dim_H});
|
||||
MP_RETURN_IF_ERROR(v_cache->LoadFromVec({}, /*exact_match=*/false));
|
||||
kv_cache.push_back(KVCache{.k_cache = k_cache, .v_cache = v_cache});
|
||||
cache = &kv_cache.back();
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto tmp, builder.SelfAttentionIncludeResidual(
|
||||
inter_layer,
|
||||
{.atten_mask = atten_masks,
|
||||
.segment_pos = segment_pos,
|
||||
.cache = cache},
|
||||
sa));
|
||||
|
||||
auto& ff = weights.ffs[i];
|
||||
ff.opt_padding = paddings;
|
||||
ASSIGN_OR_RETURN(inter_layer, builder.FeedForwardIncludeResidual(tmp, ff));
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> logits_output, transformer_output, normed_output;
|
||||
|
||||
if (!ulm_params.final_norm && !ulm_params.final_project) {
|
||||
transformer_output = inter_layer;
|
||||
transformer_output->MarkOutput();
|
||||
}
|
||||
|
||||
if (ulm_params.final_norm) {
|
||||
ASSIGN_OR_RETURN(inter_layer,
|
||||
builder.RmsNorm(inter_layer, weights.final_ln_scale));
|
||||
normed_output = inter_layer;
|
||||
normed_output->MarkOutput();
|
||||
}
|
||||
|
||||
if (ulm_params.final_project) {
|
||||
RET_CHECK(weights.softmax_linear);
|
||||
ASSIGN_OR_RETURN(logits_output,
|
||||
ApplyFinalProj(inter_layer, weights, builder));
|
||||
logits_output->MarkOutput();
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs)));
|
||||
auto ulm = std::make_unique<Ulm>(std::move(*graph));
|
||||
{
|
||||
ASSIGN_OR_RETURN(auto pos_embedding_data,
|
||||
mediapipe::PositionEmbedding(ulm_params.seq_size_T,
|
||||
ulm_params.model_dim_D));
|
||||
Tensor::DimsType dims{ulm_params.seq_size_T, ulm_params.model_dim_D};
|
||||
ulm->pos_embedding_data_ =
|
||||
std::make_shared<Tensor>(std::move(dims), xnn_datatype_fp32);
|
||||
MP_RETURN_IF_ERROR(
|
||||
ulm->pos_embedding_data_->LoadFromVec(pos_embedding_data));
|
||||
}
|
||||
ulm->transformer_input_ = input;
|
||||
ulm->transformer_output_ = transformer_output;
|
||||
ulm->normed_output_ = normed_output;
|
||||
ulm->logits_output_ = logits_output;
|
||||
ulm->segment_pos_ = segment_pos;
|
||||
ulm->atten_masks_ = atten_masks;
|
||||
if (ulm_params.use_padding) {
|
||||
ulm->paddings_ = paddings;
|
||||
}
|
||||
ulm->kv_cache_ = std::move(kv_cache);
|
||||
|
||||
ulm->weights_ = std::move(weights);
|
||||
ulm->ulm_params_ = ulm_params;
|
||||
|
||||
return ulm;
|
||||
}
|
||||
|
||||
absl::Status Ulm::InitInputTokens(const std::vector<int>& input_ids) {
|
||||
prev_ids_ = input_ids;
|
||||
|
||||
constexpr float neg_value = 0.7 * std::numeric_limits<float>::lowest();
|
||||
const auto& seq_size = ulm_params_.seq_size_T;
|
||||
std::vector<float> attention_array(seq_size * seq_size, neg_value);
|
||||
for (int i = 0; i < seq_size; ++i) {
|
||||
for (int j = 0; j < seq_size; ++j) {
|
||||
if (i < input_ids.size() && j < input_ids.size()) {
|
||||
attention_array[seq_size * i + j] = 0;
|
||||
} else if (i >= seq_size && j <= i) {
|
||||
attention_array[seq_size * i + j] = 0;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(atten_masks_->LoadFromVec(attention_array));
|
||||
|
||||
MP_RETURN_IF_ERROR(GetTokenEmbedding(input_ids,
|
||||
pos_embedding_data_->DataAs<float>(),
|
||||
transformer_input_->DataAs<float>()));
|
||||
return SetupRuntime();
|
||||
}
|
||||
|
||||
absl::Status Ulm::GetNextToken(std::vector<int>* output_ids) {
|
||||
VLOG(2) << "Decode step " << prev_ids_.size() - 1;
|
||||
|
||||
MP_RETURN_IF_ERROR(Run());
|
||||
|
||||
RET_CHECK(logits_output_);
|
||||
std::shared_ptr<Tensor> logits =
|
||||
logits_output_->Slice({0, prev_ids_.size() - 1, 0});
|
||||
DCHECK_EQ(logits->num_elements, ulm_params_.voc_size_V);
|
||||
|
||||
ASSIGN_OR_RETURN(*output_ids,
|
||||
mediapipe::SampleNextToken(
|
||||
logits->DataAs<float>(),
|
||||
/*batch_size=*/1,
|
||||
/*vocab_size=*/ulm_params_.voc_size_V, /*top_k=*/10,
|
||||
/*top_p=*/1, /*temperature=*/-1));
|
||||
RET_CHECK_EQ(output_ids->size(), 1);
|
||||
prev_ids_.push_back(output_ids->at(0));
|
||||
|
||||
return GetTokenEmbedding(
|
||||
*output_ids,
|
||||
pos_embedding_data_->Slice({prev_ids_.size() - 1, 0})->DataAs<float>(),
|
||||
transformer_input_->Slice({0, prev_ids_.size() - 1, 0})->DataAs<float>());
|
||||
}
|
||||
|
||||
absl::Status Ulm::GetTokenEmbedding(const std::vector<int>& ids,
|
||||
const float* pos_embedding_data,
|
||||
float* embedding) {
|
||||
auto token_embedding = weights_.token_embedding ? weights_.token_embedding
|
||||
: weights_.softmax_linear;
|
||||
RET_CHECK(token_embedding->dims[0] == ulm_params_.voc_size_V)
|
||||
<< "shape must be [vocab_size, _], such that following Slice() makes "
|
||||
"sense.";
|
||||
for (size_t id : ids) {
|
||||
memcpy(embedding, token_embedding->Slice(0, id)->Data(),
|
||||
ulm_params_.model_dim_D * sizeof(float));
|
||||
for (size_t i = 0; i < ulm_params_.model_dim_D; ++i) {
|
||||
embedding[i] += pos_embedding_data[i];
|
||||
}
|
||||
pos_embedding_data += ulm_params_.model_dim_D;
|
||||
embedding += ulm_params_.model_dim_D;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace xnn_utils
|
||||
} // namespace mediapipe
|
127
mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h
Normal file
127
mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h
Normal file
|
@ -0,0 +1,127 @@
|
|||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace xnn_utils {
|
||||
|
||||
class Ulm : public XnnGraph {
|
||||
public:
|
||||
using UlmParams = UlmParams;
|
||||
|
||||
explicit Ulm(XnnGraph&& other) : XnnGraph(std::move(other)) {}
|
||||
~Ulm() override = default;
|
||||
|
||||
// Creating ULM graph with default params. The default param corresponds to
|
||||
// ULM1B 256k model.
|
||||
static absl::StatusOr<std::unique_ptr<Ulm>> CreateUlm(
|
||||
absl::string_view weights_folder,
|
||||
const UlmParams& ulm_params =
|
||||
UlmParams{
|
||||
.num_transformer_M = 18,
|
||||
.batch_size_B = 1,
|
||||
.seq_size_T = 16,
|
||||
.model_dim_D = 1536,
|
||||
.hidden_dim_HD = 8 * 1536,
|
||||
.head_dim_H = 128,
|
||||
.n_heads_N = 12,
|
||||
.voc_size_V = 256128,
|
||||
},
|
||||
std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
|
||||
static absl::StatusOr<std::unique_ptr<Ulm>> CreateUlm(
|
||||
std::unique_ptr<UlmWeightsLoader> weight_loader,
|
||||
std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
|
||||
// Build the graph for one-token inference.
|
||||
static absl::StatusOr<std::unique_ptr<Ulm>> CreateOneTokenUlm(
|
||||
std::unique_ptr<UlmWeightsLoader> weight_loader,
|
||||
std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
|
||||
|
||||
// (Re)Initialize with input token ids. This will reset the cache, mask etc.
|
||||
virtual absl::Status InitInputTokens(const std::vector<int>& input_ids);
|
||||
|
||||
// Get the next token id.
|
||||
virtual absl::Status GetNextToken(std::vector<int>* output_ids);
|
||||
|
||||
protected:
|
||||
friend class OneTokenUlm;
|
||||
friend class UlmTest;
|
||||
friend class UlmBuilder;
|
||||
|
||||
// Enable if enable_kv_cache
|
||||
struct KVCache {
|
||||
std::shared_ptr<Tensor> k_cache;
|
||||
std::shared_ptr<Tensor> v_cache;
|
||||
std::shared_ptr<Tensor> k_slice;
|
||||
std::shared_ptr<Tensor> v_slice;
|
||||
};
|
||||
|
||||
absl::Status GetTokenEmbedding(const std::vector<int>& ids,
|
||||
const float* pos_embedding_data,
|
||||
float* embedding);
|
||||
|
||||
UlmWeights weights_;
|
||||
UlmParams ulm_params_;
|
||||
|
||||
std::shared_ptr<Tensor> pos_embedding_data_;
|
||||
std::shared_ptr<Tensor> atten_masks_;
|
||||
std::shared_ptr<Tensor> segment_pos_;
|
||||
std::shared_ptr<Tensor> paddings_;
|
||||
|
||||
std::shared_ptr<Tensor> transformer_input_;
|
||||
std::shared_ptr<Tensor> transformer_output_;
|
||||
std::shared_ptr<Tensor> normed_output_;
|
||||
std::shared_ptr<Tensor> logits_output_;
|
||||
|
||||
// Previous ids, including prompt.
|
||||
std::vector<int> prev_ids_;
|
||||
// If enable_kv_cache, expect a mask of [0, ... 0, 1, 0, 0...], size 1 x T.
|
||||
std::shared_ptr<Tensor> decode_step_mask_;
|
||||
// [1, 1, ..., 1, 0, 0...], applied on cache
|
||||
std::shared_ptr<Tensor> decode_step_mask_for_cache_;
|
||||
std::vector<KVCache> kv_cache_;
|
||||
};
|
||||
|
||||
class UlmBuilder : public XnnGraphBuilder {
|
||||
public:
|
||||
struct SelfAttentionArgs {
|
||||
std::shared_ptr<Tensor> atten_mask;
|
||||
std::shared_ptr<Tensor> segment_pos;
|
||||
|
||||
Ulm::KVCache* cache = nullptr;
|
||||
};
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionExcludeNorm(
|
||||
std::shared_ptr<Tensor> input, SelfAttentionArgs args,
|
||||
const SelfAttentionWeights& sa_weights,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionIncludeResidual(
|
||||
std::shared_ptr<Tensor> input, SelfAttentionArgs args,
|
||||
const SelfAttentionWeights& params,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> FeedForwardExcludeResidual(
|
||||
std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> FeedForwardIncludeResidual(
|
||||
std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
|
||||
absl::SourceLocation loc = absl::SourceLocation::current());
|
||||
};
|
||||
|
||||
} // namespace xnn_utils
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_
|
366
mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.cc
Normal file
366
mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.cc
Normal file
|
@ -0,0 +1,366 @@
|
|||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "file/base/filesystem.h"
|
||||
#include "file/base/options.h"
|
||||
#include "file/base/path.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
|
||||
#include "third_party/XNNPACK/include/xnnpack.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace xnn_utils {
|
||||
|
||||
namespace {
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefixHelper(
|
||||
XnnGraphBuilder& builder, absl::string_view prefix,
|
||||
const Tensor::DimsType& dims, size_t dim_scale_if_any) {
|
||||
RET_CHECK(!prefix.empty() && prefix.back() != '.');
|
||||
std::vector<std::string> filenames;
|
||||
auto s = file::Match(absl::StrCat(prefix, "*"), &filenames, file::Defaults());
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << s;
|
||||
return nullptr;
|
||||
} else if (filenames.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (filenames.size() == 1) {
|
||||
RET_CHECK_EQ(filenames[0], prefix);
|
||||
return builder.NewWeight(filenames[0], dims);
|
||||
}
|
||||
|
||||
bool is_quantized_tensor = false;
|
||||
for (const auto& filename : filenames) {
|
||||
if (absl::StrContains(filename, kQuantizedScaleSuffix)) {
|
||||
is_quantized_tensor = true;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
RET_CHECK(is_quantized_tensor)
|
||||
<< "At least one of {" << filenames << "} must be quantize scale file.";
|
||||
|
||||
std::shared_ptr<Tensor> result;
|
||||
result = std::make_shared<QCTensor>(dims, dim_scale_if_any);
|
||||
|
||||
MP_RETURN_IF_ERROR(result->LoadFromFile(prefix));
|
||||
builder.NewWeight(result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::Status TransposeSelfAttentionWeight(
|
||||
const UlmWeightsLoader& loader, std::shared_ptr<Tensor>& original_weight,
|
||||
absl::string_view cache_file_prefix) {
|
||||
const auto& ulm_param = loader.ulm_params();
|
||||
RET_CHECK(original_weight);
|
||||
|
||||
std::optional<int> from_cache =
|
||||
original_weight->GetMetadata(UlmWeights::kKeyLoadedFromCache);
|
||||
if (from_cache && *from_cache) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
if (auto s = original_weight->DumpToFile(cache_file_prefix); !s.ok()) {
|
||||
LOG(WARNING) << s;
|
||||
} else {
|
||||
MP_RETURN_IF_ERROR(original_weight->LoadFromFile(cache_file_prefix));
|
||||
}
|
||||
loader.builder().NewWeight(original_weight);
|
||||
original_weight->SetMetadata(XnnGraphBuilder::kKeySelfAttentionReshapedWeight,
|
||||
ulm_param.n_heads_N);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::Status PrepareTokenEmbeddingDecorator::Decorate(
|
||||
const UlmWeightsLoader& loader, UlmWeights& weight) {
|
||||
if (weight.token_embedding) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const auto& ulm_params = loader.ulm_params();
|
||||
absl::string_view cache_path = loader.ulm_params().weight_cache_path;
|
||||
std::string token_embedding_cache_path =
|
||||
cache_path.empty() ? "" : file::JoinPath(cache_path, "token_embedding.w");
|
||||
// 1. try cache
|
||||
if (!token_embedding_cache_path.empty()) {
|
||||
auto token_embedding =
|
||||
Tensor::FromFile(token_embedding_cache_path,
|
||||
{ulm_params.voc_size_V, ulm_params.model_dim_D});
|
||||
if (token_embedding.ok()) {
|
||||
weight.token_embedding = *token_embedding;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
|
||||
// 2. fill embedding from softmax_linear
|
||||
auto& softmax_linear = *weight.softmax_linear;
|
||||
RET_CHECK(softmax_linear.dims[0] == ulm_params.voc_size_V) << softmax_linear;
|
||||
if (softmax_linear.datatype == xnn_datatype_fp32) {
|
||||
weight.token_embedding = softmax_linear.View();
|
||||
} else if (softmax_linear.datatype == xnn_datatype_qcint8) {
|
||||
ASSIGN_OR_RETURN(weight.token_embedding, softmax_linear.ConvertToF32());
|
||||
}
|
||||
|
||||
float* embedding_data = weight.token_embedding->DataAs<float>();
|
||||
for (size_t i = 0; i < softmax_linear.num_elements; ++i) {
|
||||
embedding_data[i] *= std::sqrt(loader.ulm_params().model_dim_D);
|
||||
}
|
||||
|
||||
// 3. save cache
|
||||
if (!token_embedding_cache_path.empty()) {
|
||||
MP_RETURN_IF_ERROR(
|
||||
weight.token_embedding->DumpToFile(token_embedding_cache_path));
|
||||
return weight.token_embedding->LoadFromFile(token_embedding_cache_path);
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TransposeSelfAttentionWeightDecorator::Decorate(
|
||||
const UlmWeightsLoader& loader, UlmWeights& weight) {
|
||||
absl::string_view cache_path = loader.ulm_params().weight_cache_path;
|
||||
if (cache_path.empty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < weight.sas.size(); ++i) {
|
||||
auto& sa = weight.sas[i];
|
||||
auto prefix = absl::StrCat(UlmWeightsLoader::kTransformerWeightPrefix, i,
|
||||
".self_attention.");
|
||||
MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight(
|
||||
loader, sa.k_weight,
|
||||
file::JoinPath(cache_path, absl::StrCat(prefix, "k.w"))));
|
||||
MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight(
|
||||
loader, sa.q_weight,
|
||||
file::JoinPath(cache_path, absl::StrCat(prefix, "q.w"))));
|
||||
MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight(
|
||||
loader, sa.v_weight,
|
||||
file::JoinPath(cache_path, absl::StrCat(prefix, "v.w"))));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> UlmWeightsLoader::LoadFromAbsPathPrefix(
|
||||
absl::string_view prefix, const Tensor::DimsType& dims,
|
||||
size_t dim_scale_if_any) const {
|
||||
return LoadFromAbsPathPrefixHelper(*builder_, prefix, dims, dim_scale_if_any);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>>
|
||||
UlmWeightsLoader::TryCacheThenLoadSelfAttention(
|
||||
absl::string_view filename_prefix) const {
|
||||
ASSIGN_OR_RETURN(
|
||||
auto r,
|
||||
TryCacheThenLoadWeightTranspose(
|
||||
filename_prefix,
|
||||
{params_.model_dim_D, params_.n_heads_N * params_.head_dim_H}, 1));
|
||||
r->SetMetadata(XnnGraphBuilder::kKeySelfAttentionReshapedWeight,
|
||||
params_.n_heads_N);
|
||||
return r;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>>
|
||||
UlmWeightsLoader::TryCacheThenLoadFeedForward(
|
||||
absl::string_view filename_prefix,
|
||||
std::optional<Tensor::DimsType> dims) const {
|
||||
if (!dims) {
|
||||
dims = {params_.model_dim_D, params_.hidden_dim_HD};
|
||||
}
|
||||
return TryCacheThenLoadWeightTranspose(filename_prefix, *dims, 1);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>>
|
||||
UlmWeightsLoader::TryCacheThenLoadWeightTranspose(
|
||||
absl::string_view filename_prefix, Tensor::DimsType original_dims,
|
||||
size_t original_dim_cale) const {
|
||||
if (!params_.weight_cache_path.empty()) {
|
||||
auto cache_full_prefix =
|
||||
file::JoinPath(params_.weight_cache_path, filename_prefix);
|
||||
Tensor::DimsType cache_dim{original_dims.rbegin(), original_dims.rend()};
|
||||
ASSIGN_OR_RETURN(auto r, LoadFromAbsPathPrefix(
|
||||
cache_full_prefix, std::move(cache_dim),
|
||||
/*dim_scale_if_any=*/1 - original_dim_cale));
|
||||
if (r) {
|
||||
r->SetMetadata(UlmWeights::kKeyLoadedFromCache, 1);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(auto r, LoadFromAbsPathPrefix(
|
||||
file::JoinPath(weight_path_, filename_prefix),
|
||||
std::move(original_dims),
|
||||
/*dim_scale_if_any=*/original_dim_cale));
|
||||
RET_CHECK(r) << file::JoinPath(weight_path_, filename_prefix);
|
||||
r = r->Transpose();
|
||||
builder_->NewWeight(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
absl::StatusOr<FeedForwardWeights> UlmWeightsLoader::LoadFeedForward(
|
||||
int layer_id) {
|
||||
absl::string_view weights_folder = weight_path_;
|
||||
const auto& params = params_;
|
||||
auto ff_file_prefix =
|
||||
absl::StrCat(kTransformerWeightPrefix, layer_id, ".ff_layer.");
|
||||
auto ff_prefix = file::JoinPath(weights_folder, ff_file_prefix);
|
||||
FeedForwardWeights feed_forward;
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
feed_forward.pre_norm,
|
||||
LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "pre_layer_norm.scale"),
|
||||
{params.model_dim_D}));
|
||||
ASSIGN_OR_RETURN(
|
||||
feed_forward.post_norm,
|
||||
LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "post_layer_norm.scale"),
|
||||
{params.model_dim_D}));
|
||||
ASSIGN_OR_RETURN(
|
||||
feed_forward.layer_1_bias,
|
||||
LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer1.bias.b"),
|
||||
{params.hidden_dim_HD}));
|
||||
ASSIGN_OR_RETURN(feed_forward.layer_1_weight,
|
||||
TryCacheThenLoadFeedForward(
|
||||
absl::StrCat(ff_file_prefix, "ffn_layer1.linear.w")));
|
||||
ASSIGN_OR_RETURN(
|
||||
feed_forward.layer_1_gate_bias,
|
||||
LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer1_gate.bias.b"),
|
||||
{params.hidden_dim_HD}));
|
||||
ASSIGN_OR_RETURN(feed_forward.layer_1_gate_weight,
|
||||
TryCacheThenLoadFeedForward(absl::StrCat(
|
||||
ff_file_prefix, "ffn_layer1_gate.linear.w")));
|
||||
ASSIGN_OR_RETURN(
|
||||
feed_forward.layer_2_bias,
|
||||
LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer2.bias.b"),
|
||||
{params.model_dim_D}, /*dim_scale_if_any=*/0));
|
||||
ASSIGN_OR_RETURN(
|
||||
feed_forward.layer_2_weight,
|
||||
TryCacheThenLoadFeedForward(
|
||||
absl::StrCat(ff_file_prefix, "ffn_layer2.linear.w"),
|
||||
Tensor::DimsType{params.hidden_dim_HD, params.model_dim_D}));
|
||||
|
||||
return feed_forward;
|
||||
}
|
||||
|
||||
absl::StatusOr<SelfAttentionWeights> UlmWeightsLoader::LoadSelfAttention(
|
||||
int layer_id) {
|
||||
absl::string_view weights_folder = weight_path_;
|
||||
const auto& params = params_;
|
||||
SelfAttentionWeights self_attention;
|
||||
|
||||
auto sa_file_prefix = absl::StrCat(kTransformerWeightPrefix, layer_id);
|
||||
auto sa_prefix = file::JoinPath(weights_folder, sa_file_prefix);
|
||||
ASSIGN_OR_RETURN(
|
||||
self_attention.pre_norm,
|
||||
LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, ".pre_layer_norm.scale"),
|
||||
{params.model_dim_D}));
|
||||
ASSIGN_OR_RETURN(
|
||||
self_attention.post_norm,
|
||||
LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, ".post_layer_norm.scale"),
|
||||
{params.model_dim_D}));
|
||||
|
||||
absl::StrAppend(&sa_file_prefix, ".self_attention.");
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
self_attention.k_weight,
|
||||
TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "k.w")));
|
||||
ASSIGN_OR_RETURN(
|
||||
self_attention.q_weight,
|
||||
TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "q.w")));
|
||||
ASSIGN_OR_RETURN(
|
||||
self_attention.v_weight,
|
||||
TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "v.w")));
|
||||
|
||||
sa_prefix = file::JoinPath(weights_folder, sa_file_prefix);
|
||||
ASSIGN_OR_RETURN(self_attention.per_dim_scale,
|
||||
LoadFromAbsPathPrefix(
|
||||
absl::StrCat(sa_prefix, "per_dim_scale.per_dim_scale"),
|
||||
{params.head_dim_H}));
|
||||
ASSIGN_OR_RETURN(self_attention.post_proj_weight,
|
||||
LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, "post.w"),
|
||||
{params.model_dim_D,
|
||||
params.n_heads_N * params.head_dim_H},
|
||||
/*dim_scale_if_any=*/0));
|
||||
|
||||
return self_attention;
|
||||
}
|
||||
|
||||
absl::StatusOr<UlmWeights> UlmWeightsLoader::LoadWeights() {
|
||||
absl::string_view weights_folder = weight_path_;
|
||||
const auto& params = params_;
|
||||
UlmWeights result;
|
||||
|
||||
for (int layer_id = 0; layer_id < params.num_transformer_M; ++layer_id) {
|
||||
ASSIGN_OR_RETURN(auto ff, LoadFeedForward(layer_id));
|
||||
result.ffs.push_back(std::move(ff));
|
||||
ASSIGN_OR_RETURN(auto sa, LoadSelfAttention(layer_id));
|
||||
result.sas.push_back(std::move(sa));
|
||||
}
|
||||
if (params.final_norm) {
|
||||
ASSIGN_OR_RETURN(result.final_ln_scale,
|
||||
LoadFromAbsPathPrefix(
|
||||
file::JoinPath(weights_folder, kFinalScaleFilename),
|
||||
{params.model_dim_D}));
|
||||
}
|
||||
ASSIGN_OR_RETURN(result.softmax_bias,
|
||||
LoadFromAbsPathPrefix(
|
||||
file::JoinPath(weights_folder, kLogitsFfnBiasFilename),
|
||||
{params.voc_size_V}));
|
||||
ASSIGN_OR_RETURN(result.softmax_linear,
|
||||
TryCacheThenLoadWeightTranspose(
|
||||
kLogitsFfnWeightFilename,
|
||||
{params.model_dim_D, params.voc_size_V}, 1));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
BenchmarkUlmWeightsLoader::BenchmarkUlmWeightsLoader(const UlmParams& params,
|
||||
xnn_datatype data_type)
|
||||
: DefaultUlmWeightsLoader("", params), data_type_(data_type) {
|
||||
params_.weight_cache_path.clear();
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>>
|
||||
BenchmarkUlmWeightsLoader::TryCacheThenLoadWeightTranspose(
|
||||
absl::string_view filename_prefix, Tensor::DimsType original_dims,
|
||||
size_t original_dim_cale) const {
|
||||
auto result = std::make_shared<QCTensor>(
|
||||
Tensor::DimsType{original_dims.rbegin(), original_dims.rend()},
|
||||
1 - original_dim_cale);
|
||||
auto real_data = std::make_shared<std::string>(result->num_elements, 0xA5);
|
||||
result->flat_data = std::shared_ptr<char>(real_data, real_data->data());
|
||||
auto real_scale = std::make_shared<std::vector<float>>(
|
||||
original_dims[original_dim_cale], 1.0f);
|
||||
result->scale_data = std::shared_ptr<float>(real_scale, real_scale->data());
|
||||
builder_->NewWeight(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>>
|
||||
BenchmarkUlmWeightsLoader::LoadFromAbsPathPrefix(
|
||||
absl::string_view prefix, const Tensor::DimsType& dims,
|
||||
size_t dim_scale_if_any) const {
|
||||
// If loader calls this function directly, it's always non-quantized weights.
|
||||
auto result = std::make_shared<Tensor>(dims);
|
||||
MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false));
|
||||
builder_->NewWeight(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace xnn_utils
|
||||
} // namespace mediapipe
|
192
mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h
Normal file
192
mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h
Normal file
|
@ -0,0 +1,192 @@
|
|||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
|
||||
#include "third_party/XNNPACK/include/xnnpack.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace xnn_utils {
|
||||
|
||||
struct UlmParams {
|
||||
size_t num_transformer_M = 18;
|
||||
size_t batch_size_B = 1;
|
||||
size_t seq_size_T = 16;
|
||||
size_t model_dim_D = 1536;
|
||||
size_t hidden_dim_HD = 8 * 1536;
|
||||
size_t head_dim_H = 128;
|
||||
size_t n_heads_N = 12;
|
||||
size_t voc_size_V = 32000;
|
||||
|
||||
bool use_padding = true;
|
||||
bool final_norm = true;
|
||||
bool final_project = true;
|
||||
|
||||
bool enable_kv_cache = false;
|
||||
// Path to store reshaped weights as cache. Set empty to disable caching.
|
||||
std::string weight_cache_path;
|
||||
};
|
||||
|
||||
struct SelfAttentionWeights {
|
||||
std::shared_ptr<Tensor> pre_norm;
|
||||
|
||||
std::shared_ptr<Tensor> k_weight;
|
||||
std::shared_ptr<Tensor> q_weight;
|
||||
std::shared_ptr<Tensor> v_weight;
|
||||
std::shared_ptr<Tensor> per_dim_scale;
|
||||
std::shared_ptr<Tensor> post_proj_weight;
|
||||
|
||||
std::shared_ptr<Tensor> post_norm;
|
||||
};
|
||||
|
||||
struct FeedForwardWeights {
|
||||
std::shared_ptr<Tensor> pre_norm;
|
||||
std::shared_ptr<Tensor> layer_1_weight;
|
||||
std::shared_ptr<Tensor> layer_1_bias;
|
||||
std::shared_ptr<Tensor> layer_1_gate_weight;
|
||||
std::shared_ptr<Tensor> layer_1_gate_bias;
|
||||
std::shared_ptr<Tensor> layer_2_weight;
|
||||
std::shared_ptr<Tensor> layer_2_bias;
|
||||
std::shared_ptr<Tensor> post_norm;
|
||||
|
||||
std::shared_ptr<Tensor> opt_padding;
|
||||
};
|
||||
|
||||
struct UlmWeights {
|
||||
std::vector<FeedForwardWeights> ffs;
|
||||
std::vector<SelfAttentionWeights> sas;
|
||||
std::shared_ptr<Tensor> final_ln_scale;
|
||||
std::shared_ptr<Tensor> softmax_linear;
|
||||
std::shared_ptr<Tensor> softmax_bias;
|
||||
|
||||
// Optional. Usually softmax_linear can be used as embedding, but sometimes we
|
||||
// need to scale/transpose it.
|
||||
std::shared_ptr<Tensor> token_embedding;
|
||||
|
||||
static constexpr absl::string_view kKeyLoadedFromCache{"loaded_from_cache"};
|
||||
};
|
||||
|
||||
class UlmWeightsLoader {
|
||||
public:
|
||||
constexpr static absl::string_view kTransformerWeightPrefix{
|
||||
"params.lm.transformer.x_layers_"};
|
||||
constexpr static absl::string_view kFinalScaleFilename{
|
||||
"params.lm.final_ln.scale"};
|
||||
constexpr static absl::string_view kLogitsFfnBiasFilename{
|
||||
"params.lm.softmax.logits_ffn.bias.b"};
|
||||
constexpr static absl::string_view kLogitsFfnWeightFilename{
|
||||
"params.lm.softmax.logits_ffn.linear.w"};
|
||||
|
||||
UlmWeightsLoader(absl::string_view weight_path, const UlmParams& params)
|
||||
: weight_path_(weight_path), params_(params) {}
|
||||
virtual ~UlmWeightsLoader() = default;
|
||||
|
||||
void SetBuilder(XnnGraphBuilder& builder) { builder_ = &builder; }
|
||||
|
||||
virtual absl::StatusOr<UlmWeights> LoadWeights();
|
||||
|
||||
virtual absl::StatusOr<SelfAttentionWeights> LoadSelfAttention(int layer_id);
|
||||
virtual absl::StatusOr<FeedForwardWeights> LoadFeedForward(int layer_id);
|
||||
|
||||
UlmParams& ulm_params() { return params_; }
|
||||
const UlmParams& ulm_params() const { return params_; }
|
||||
XnnGraphBuilder& builder() const { return *builder_; }
|
||||
|
||||
protected:
|
||||
// Find the files that matches prefix, then read from file.
|
||||
virtual absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefix(
|
||||
absl::string_view prefix, const Tensor::DimsType& dims,
|
||||
size_t dim_scale_if_any) const;
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefix(
|
||||
absl::string_view prefix, const Tensor::DimsType& dims) const {
|
||||
return LoadFromAbsPathPrefix(prefix, dims, 0);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> TryCacheThenLoadSelfAttention(
|
||||
absl::string_view filename_prefix) const;
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> TryCacheThenLoadFeedForward(
|
||||
absl::string_view filename_prefix,
|
||||
std::optional<Tensor::DimsType> dims = std::nullopt) const;
|
||||
virtual absl::StatusOr<std::shared_ptr<Tensor>>
|
||||
TryCacheThenLoadWeightTranspose(absl::string_view filename_prefix,
|
||||
Tensor::DimsType original_dims,
|
||||
size_t original_dim_cale) const;
|
||||
|
||||
std::string weight_path_;
|
||||
UlmParams params_;
|
||||
XnnGraphBuilder* builder_ = nullptr;
|
||||
};
|
||||
|
||||
// Try: 1. load token embedding from cache; 2. fill token embedding by transpose
|
||||
// softmax linear then scale; 3. dump token embedding to cache.
|
||||
struct PrepareTokenEmbeddingDecorator {
|
||||
static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&);
|
||||
};
|
||||
struct TransposeSoftmaxWeightDecorator {
|
||||
static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&);
|
||||
};
|
||||
struct TransposeSelfAttentionWeightDecorator {
|
||||
// If KQV weight are reshaped, ignore.
|
||||
// If KQV weight are not properly shaped, load from cache if any, or build.
|
||||
// If KQV weight are missing, try loading from cache path, or fail if missing.
|
||||
static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&);
|
||||
};
|
||||
|
||||
// Apply some decoration (in order) to the weights loaded by base class.
|
||||
template <class... Decorators>
|
||||
class UlmWeightsLoaderWith : public UlmWeightsLoader {
|
||||
public:
|
||||
UlmWeightsLoaderWith(absl::string_view weight_path, const UlmParams& params)
|
||||
: UlmWeightsLoader(weight_path, params),
|
||||
decorators_{Decorators::Decorate...} {}
|
||||
|
||||
absl::StatusOr<UlmWeights> LoadWeights() override {
|
||||
ASSIGN_OR_RETURN(auto result, UlmWeightsLoader::LoadWeights());
|
||||
for (const auto& decorator : decorators_) {
|
||||
MP_RETURN_IF_ERROR(decorator(*this, result));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::function<absl::Status(const UlmWeightsLoader&, UlmWeights&)>>
|
||||
decorators_;
|
||||
};
|
||||
|
||||
using DefaultUlmWeightsLoader =
|
||||
UlmWeightsLoaderWith<TransposeSelfAttentionWeightDecorator,
|
||||
PrepareTokenEmbeddingDecorator>;
|
||||
|
||||
// Generate weights with some random value.
|
||||
class BenchmarkUlmWeightsLoader : public DefaultUlmWeightsLoader {
|
||||
public:
|
||||
explicit BenchmarkUlmWeightsLoader(
|
||||
const UlmParams& params, xnn_datatype data_type = xnn_datatype_fp32);
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> TryCacheThenLoadWeightTranspose(
|
||||
absl::string_view filename_prefix, Tensor::DimsType original_dims,
|
||||
size_t original_dim_cale) const override;
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefix(
|
||||
absl::string_view prefix, const Tensor::DimsType& dims,
|
||||
size_t dim_scale_if_any) const override;
|
||||
|
||||
private:
|
||||
xnn_datatype data_type_;
|
||||
std::shared_ptr<Tensor> random_value_buffer_;
|
||||
};
|
||||
|
||||
} // namespace xnn_utils
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_
|
21
mediapipe/tasks/cc/text/utils/xnn_utils/utils.cc
Normal file
21
mediapipe/tasks/cc/text/utils/xnn_utils/utils.cc
Normal file
|
@ -0,0 +1,21 @@
|
|||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace xnn_utils {
|
||||
|
||||
std::vector<float> FillXnnRoPEWeights(size_t max_seq_len, size_t num_channels) {
|
||||
std::vector<float> out_array(max_seq_len * num_channels);
|
||||
for (size_t ch_id = 0; ch_id < num_channels / 2; ++ch_id) {
|
||||
auto timescale = std::pow(1e-4, 2.0 * ch_id / num_channels);
|
||||
for (size_t seq_id = 0; seq_id < max_seq_len; ++seq_id) {
|
||||
auto sinusoid_inp = seq_id * timescale;
|
||||
out_array[seq_id * num_channels + ch_id] = cos(sinusoid_inp);
|
||||
out_array[seq_id * num_channels + ch_id + num_channels / 2] =
|
||||
sin(sinusoid_inp);
|
||||
}
|
||||
}
|
||||
return out_array;
|
||||
}
|
||||
|
||||
} // namespace xnn_utils
|
||||
} // namespace mediapipe
|
61
mediapipe/tasks/cc/text/utils/xnn_utils/utils.h
Normal file
61
mediapipe/tasks/cc/text/utils/xnn_utils/utils.h
Normal file
|
@ -0,0 +1,61 @@
|
|||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
|
||||
#include "absl/cleanup/cleanup.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "file/base/helpers.h"
|
||||
#include "file/base/options.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace xnn_utils {
|
||||
|
||||
std::vector<float> FillXnnRoPEWeights(size_t max_seq_len, size_t num_channels);
|
||||
|
||||
// expect_size_bytes == 0 means don't check size.
|
||||
template <typename element_type = char>
|
||||
static absl::StatusOr<std::shared_ptr<element_type>> LoadBufferFromFile(
|
||||
absl::string_view file_path, bool use_mmap = true,
|
||||
size_t expect_size_bytes = 0) {
|
||||
if (use_mmap) {
|
||||
int fd = open(file_path.data(), O_RDONLY);
|
||||
RET_CHECK_GE(fd, 0) << "open " << file_path << " failed";
|
||||
auto cleanup = absl::MakeCleanup([fd] { close(fd); });
|
||||
|
||||
const size_t size = lseek(fd, 0, SEEK_END);
|
||||
if (expect_size_bytes) {
|
||||
RET_CHECK_EQ(expect_size_bytes, size)
|
||||
<< "File size " << size << ", expected " << expect_size_bytes
|
||||
<< ", file path " << file_path;
|
||||
}
|
||||
|
||||
void* data = mmap(/*addr=*/nullptr, size, /*prot=*/PROT_READ,
|
||||
/*flags=*/MAP_SHARED, fd, /*offset=*/0);
|
||||
RET_CHECK_NE(data, MAP_FAILED);
|
||||
RET_CHECK_NE(data, nullptr);
|
||||
|
||||
return std::shared_ptr<element_type>(static_cast<element_type*>(data),
|
||||
[](auto* p) {});
|
||||
} else {
|
||||
auto read_buffer = std::make_shared<std::string>();
|
||||
MP_RETURN_IF_ERROR(
|
||||
file::GetContents(file_path, read_buffer.get(), file::Defaults()));
|
||||
|
||||
if (expect_size_bytes) {
|
||||
RET_CHECK_EQ(expect_size_bytes, read_buffer->size())
|
||||
<< "File size " << read_buffer->size() << ", expected "
|
||||
<< expect_size_bytes << ", file path " << file_path;
|
||||
}
|
||||
|
||||
return std::shared_ptr<element_type>(
|
||||
read_buffer, reinterpret_cast<element_type*>(read_buffer->data()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xnn_utils
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_
|
358
mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.cc
Normal file
358
mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.cc
Normal file
|
@ -0,0 +1,358 @@
|
|||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "file/base/helpers.h"
|
||||
#include "file/base/options.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
|
||||
#include "third_party/XNNPACK/include/xnnpack.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace xnn_utils {
|
||||
|
||||
absl::Status FillXnnRoPEWeights(Tensor& out_seg_pos) {
|
||||
RET_CHECK_EQ(out_seg_pos.dims.size(), 2);
|
||||
const size_t max_seq_len = out_seg_pos.dims[0];
|
||||
const size_t num_channels = out_seg_pos.dims[1];
|
||||
return out_seg_pos.LoadFromVec(FillXnnRoPEWeights(max_seq_len, num_channels));
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Tensor& tensor) {
|
||||
os << "Tensor{dims=[" << tensor.dims << "], datatype=" << tensor.datatype
|
||||
<< ", num_elements=" << tensor.num_elements << "}";
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const QCTensor& tensor) {
|
||||
os << "QCTensor{dims=[" << tensor.dims << "], dim_scale=" << tensor.dim_scale
|
||||
<< " datatype=" << tensor.datatype
|
||||
<< ", num_elements=" << tensor.num_elements << "}";
|
||||
return os;
|
||||
}
|
||||
|
||||
bool Tensor::operator==(const Tensor& other) const {
|
||||
if (dims.size() != other.dims.size()) {
|
||||
return false;
|
||||
} else if (datatype != other.datatype) {
|
||||
return false;
|
||||
} else {
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
if (dims[i] != other.dims[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0 == memcmp(Data(), other.Data(), num_elements * ElementSize());
|
||||
}
|
||||
|
||||
void Tensor::AllocateBufferIfNeeded() {
|
||||
if (!flat_data) {
|
||||
auto real_buffer = std::make_shared<std::string>();
|
||||
real_buffer->reserve(num_elements * ElementSize() + XNN_EXTRA_BYTES);
|
||||
flat_data = std::shared_ptr<char>(real_buffer, real_buffer->data());
|
||||
}
|
||||
}
|
||||
|
||||
void* Tensor::Data() {
|
||||
DCHECK(flat_data)
|
||||
<< "If this is weight, you may need to call one of the LoadFrom*()";
|
||||
return flat_data.get();
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> Tensor::Slice(DimsType offset) {
|
||||
DCHECK(flat_data);
|
||||
CHECK_EQ(offset.size(), dims.size()) << offset << " vs. " << dims;
|
||||
// offset: [0, k, 0, 0], dims: [1, K, _, _]. dims before k must be 1.
|
||||
bool found_non_zero_offset = false;
|
||||
int index_k = -1;
|
||||
for (int i = 0; i < dims.size(); ++i) {
|
||||
if (found_non_zero_offset) {
|
||||
DCHECK_EQ(offset[i], 0);
|
||||
} else if (offset[i] != 0) {
|
||||
found_non_zero_offset = true;
|
||||
index_k = i;
|
||||
}
|
||||
}
|
||||
DCHECK(found_non_zero_offset) << offset;
|
||||
|
||||
return Slice(index_k, offset[index_k]);
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> Tensor::Slice(size_t index, size_t offset) {
|
||||
size_t num_elements_offset = 1;
|
||||
DimsType new_dim = dims;
|
||||
for (int i = 0; i < dims.size(); ++i) {
|
||||
if (i < index) {
|
||||
DCHECK_EQ(dims[i], 1);
|
||||
} else if (i == index) {
|
||||
num_elements_offset *= offset;
|
||||
new_dim[i] = 1;
|
||||
} else {
|
||||
num_elements_offset *= dims[i];
|
||||
}
|
||||
}
|
||||
|
||||
auto result = std::make_shared<Tensor>(std::move(new_dim), datatype);
|
||||
result->flat_data = std::shared_ptr<char>(
|
||||
flat_data, flat_data.get() + num_elements_offset * ElementSize());
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor& Tensor::Borrow(std::shared_ptr<Tensor> other, size_t element_offset) {
|
||||
DCHECK_EQ(datatype, other->datatype);
|
||||
DCHECK_EQ(dims.size(), other->dims.size());
|
||||
flat_data = std::shared_ptr<char>(
|
||||
other->flat_data,
|
||||
other->flat_data.get() + element_offset * ElementSize());
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> Tensor::View() { return View(dims); }
|
||||
|
||||
std::shared_ptr<Tensor> Tensor::View(DimsType as_dims, size_t) {
|
||||
auto result = std::make_shared<Tensor>(as_dims, datatype);
|
||||
DCHECK_LE(result->num_elements, num_elements);
|
||||
result->flat_data = flat_data;
|
||||
return result;
|
||||
}
|
||||
|
||||
const void* Tensor::Data() const { return const_cast<Tensor*>(this)->Data(); }
|
||||
|
||||
absl::Status Tensor::DefineAsExternal(xnn_subgraph& subgraph, uint32_t flags) {
|
||||
uint32_t id;
|
||||
RET_CHECK_EQ(xnn_status_success,
|
||||
xnn_define_tensor_value(&subgraph, datatype, dims.size(),
|
||||
dims.data(), /*data=*/nullptr,
|
||||
/*external_id=*/tensor_id, flags, &id));
|
||||
if (tensor_id == XNN_INVALID_VALUE_ID) {
|
||||
RET_CHECK_NE(id, XNN_INVALID_VALUE_ID);
|
||||
tensor_id = id;
|
||||
} else {
|
||||
RET_CHECK_EQ(id, tensor_id);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Tensor::DefineAsInput(xnn_subgraph& subgraph) {
|
||||
return DefineAsExternal(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT);
|
||||
}
|
||||
|
||||
absl::Status Tensor::DefineAsOutput(xnn_subgraph& subgraph) {
|
||||
return DefineAsExternal(subgraph, XNN_VALUE_FLAG_EXTERNAL_OUTPUT);
|
||||
}
|
||||
|
||||
absl::Status Tensor::DefineAsIntermediateTensor(xnn_subgraph& subgraph) {
|
||||
RET_CHECK_EQ(tensor_id, XNN_INVALID_VALUE_ID);
|
||||
return DefineAsExternal(subgraph, 0);
|
||||
}
|
||||
|
||||
absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) {
|
||||
RET_CHECK_EQ(
|
||||
xnn_status_success,
|
||||
xnn_define_tensor_value(&subgraph, datatype, dims.size(), dims.data(),
|
||||
Data(), tensor_id, flags, &tensor_id));
|
||||
RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph) {
|
||||
RET_CHECK_EQ(tensor_id, XNN_INVALID_VALUE_ID);
|
||||
return DefineWeight(subgraph, 0);
|
||||
}
|
||||
|
||||
absl::Status Tensor::DefineRope(xnn_subgraph& subgraph) {
|
||||
RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID);
|
||||
return DefineWeight(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT);
|
||||
}
|
||||
|
||||
absl::Status Tensor::LoadFromBuffer(const void* buffer) {
|
||||
AllocateBufferIfNeeded();
|
||||
memcpy(Data(), buffer, num_elements * ElementSize());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Tensor::LoadFromVec(const std::vector<float>& data,
|
||||
bool exact_match) {
|
||||
AllocateBufferIfNeeded();
|
||||
if (exact_match) {
|
||||
RET_CHECK_EQ(num_elements * ElementSize(), data.size() * sizeof(float));
|
||||
}
|
||||
|
||||
memcpy(Data(), data.data(), data.size() * sizeof(float));
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Tensor::LoadFromVec(std::vector<float>&& data, bool exact_match) {
|
||||
if (exact_match) {
|
||||
RET_CHECK_EQ(num_elements * ElementSize(), data.size() * sizeof(float));
|
||||
}
|
||||
|
||||
auto real_buffer = std::make_shared<std::vector<float>>(std::move(data));
|
||||
if (real_buffer->size() < num_elements) {
|
||||
real_buffer->resize(num_elements);
|
||||
}
|
||||
flat_data = std::shared_ptr<char>(
|
||||
real_buffer, reinterpret_cast<char*>(real_buffer->data()));
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Tensor::DumpToBuffer(void* buffer) {
|
||||
memcpy(buffer, Data(), num_elements * ElementSize());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Tensor::DumpToVec(std::vector<float>& out_data, bool exact_match) {
|
||||
if (exact_match) {
|
||||
RET_CHECK_EQ(num_elements * ElementSize(), out_data.size() * sizeof(float));
|
||||
} else {
|
||||
out_data.resize(num_elements);
|
||||
}
|
||||
memcpy(out_data.data(), Data(), num_elements * ElementSize());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Tensor::DumpToFile(absl::string_view file_path) {
|
||||
return file::SetContents(
|
||||
file_path,
|
||||
absl::string_view(flat_data.get(), num_elements * ElementSize()),
|
||||
file::Defaults());
|
||||
}
|
||||
|
||||
absl::Status Tensor::LoadFromFile(absl::string_view file_path, bool use_mmap,
|
||||
bool exact_match) {
|
||||
const size_t expected_size_in_bytes =
|
||||
exact_match ? num_elements * ElementSize() : 0;
|
||||
|
||||
ASSIGN_OR_RETURN(flat_data, LoadBufferFromFile(file_path, use_mmap,
|
||||
expected_size_in_bytes));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> Tensor::Transpose() {
|
||||
DCHECK_EQ(dims.size(), 2);
|
||||
DimsType out_dims{dims.rbegin(), dims.rend()};
|
||||
auto result = std::make_shared<Tensor>(std::move(out_dims), datatype);
|
||||
result->AllocateBufferIfNeeded();
|
||||
xnn_status s;
|
||||
const DimsType perm{1, 0};
|
||||
if (datatype == xnn_datatype_fp32) {
|
||||
s = xnn_run_transpose_nd_x32(Data(), result->Data(), dims.size(),
|
||||
dims.data(), perm.data(),
|
||||
/*flags=*/0, /*threadpool=*/nullptr);
|
||||
} else {
|
||||
LOG(FATAL) << "Need update to support new type";
|
||||
}
|
||||
DCHECK_EQ(s, xnn_status_success);
|
||||
return (s == xnn_status_success) ? result : nullptr;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> Tensor::ConvertToF32() {
|
||||
auto result = std::make_shared<Tensor>(dims, xnn_datatype_fp32);
|
||||
MP_RETURN_IF_ERROR(result->LoadFromBuffer(Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::Status QCTensor::LoadFromFile(absl::string_view quantized_weight_filename,
|
||||
absl::string_view scale_filename,
|
||||
bool use_mmap, bool exact_match) {
|
||||
size_t scale_element_size = dims[dim_scale];
|
||||
|
||||
ASSIGN_OR_RETURN(flat_data,
|
||||
LoadBufferFromFile(quantized_weight_filename, use_mmap,
|
||||
exact_match ? num_elements : 0));
|
||||
ASSIGN_OR_RETURN(scale_data,
|
||||
LoadBufferFromFile<float>(
|
||||
scale_filename, use_mmap,
|
||||
exact_match ? scale_element_size * sizeof(float) : 0));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status QCTensor::DumpToFile(absl::string_view file_path) {
|
||||
MP_RETURN_IF_ERROR(file::SetContents(
|
||||
file_path,
|
||||
absl::string_view(flat_data.get(), num_elements * ElementSize()),
|
||||
file::Defaults()));
|
||||
return file::SetContents(
|
||||
absl::StrCat(file_path, kQuantizedScaleSuffix),
|
||||
absl::string_view(reinterpret_cast<char*>(scale_data.get()),
|
||||
dims[dim_scale] * sizeof(float)),
|
||||
file::Defaults());
|
||||
}
|
||||
|
||||
absl::Status QCTensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) {
|
||||
RET_CHECK_EQ(
|
||||
xnn_status_success,
|
||||
xnn_define_channelwise_quantized_tensor_value(
|
||||
&subgraph, datatype, scale_data.get(), dims.size(), dim_scale,
|
||||
dims.data(), Data(), XNN_INVALID_VALUE_ID, flags, &tensor_id))
|
||||
<< *this;
|
||||
RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void QCTensor::AllocateBufferIfNeeded() {
|
||||
Tensor::AllocateBufferIfNeeded();
|
||||
if (!scale_data) {
|
||||
auto real_buffer = std::make_shared<std::vector<float>>();
|
||||
real_buffer->reserve(dims[dim_scale]);
|
||||
scale_data = std::shared_ptr<float>(real_buffer, real_buffer->data());
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> QCTensor::Transpose() {
|
||||
DCHECK_EQ(dims.size(), 2);
|
||||
size_t channel_size = dims[dim_scale];
|
||||
DimsType out_dims{dims.rbegin(), dims.rend()};
|
||||
auto result = std::make_shared<QCTensor>(std::move(out_dims), 1 - dim_scale);
|
||||
result->AllocateBufferIfNeeded();
|
||||
memcpy(result->scale_data.get(), scale_data.get(),
|
||||
channel_size * sizeof(float));
|
||||
xnn_status s;
|
||||
const DimsType perm{1, 0};
|
||||
if (datatype == xnn_datatype_qcint8) {
|
||||
s = xnn_run_transpose_nd_x8(Data(), result->Data(), dims.size(),
|
||||
dims.data(), perm.data(),
|
||||
/*flags=*/0, /*threadpool=*/nullptr);
|
||||
} else {
|
||||
LOG(FATAL) << "Need update to support new type";
|
||||
}
|
||||
DCHECK_EQ(s, xnn_status_success);
|
||||
return (s == xnn_status_success) ? result : nullptr;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> QCTensor::ConvertToF32() {
|
||||
auto result = std::make_shared<Tensor>(dims, xnn_datatype_fp32);
|
||||
// TODO: proper implement.
|
||||
LOG(WARNING) << "This is fake impl";
|
||||
MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false));
|
||||
return result;
|
||||
}
|
||||
|
||||
std::shared_ptr<Tensor> QCTensor::View(DimsType as_dims,
|
||||
size_t dim_scale_if_any) {
|
||||
auto result = std::make_shared<QCTensor>(as_dims, dim_scale_if_any);
|
||||
DCHECK_LE(result->num_elements, num_elements);
|
||||
result->flat_data = flat_data;
|
||||
result->scale_data = scale_data;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace xnn_utils
|
||||
} // namespace mediapipe
|
202
mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h
Normal file
202
mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h
Normal file
|
@ -0,0 +1,202 @@
|
|||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/cleanup/cleanup.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "file/base/helpers.h"
|
||||
#include "file/base/options.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
|
||||
#include "third_party/XNNPACK/include/xnnpack.h"
|
||||
#include "util/gtl/stl_logging.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace xnn_utils {
|
||||
|
||||
static constexpr absl::string_view kQuantizedScaleSuffix{"_quantized_scale"};
|
||||
static constexpr absl::string_view kSparsityParamsSuffix{"_sparsity_params"};
|
||||
|
||||
struct Tensor {
|
||||
using DimsType = std::vector<size_t>;
|
||||
|
||||
explicit Tensor(DimsType in_dims, xnn_datatype datatype_ = xnn_datatype_fp32)
|
||||
: dims(std::move(in_dims)),
|
||||
num_elements(dims.empty() ? 0
|
||||
: std::accumulate(std::begin(dims),
|
||||
std::end(dims), size_t(1),
|
||||
std::multiplies<size_t>())),
|
||||
datatype(datatype_) {}
|
||||
Tensor(Tensor&& other) = default;
|
||||
|
||||
Tensor& operator=(const Tensor& other) = delete;
|
||||
Tensor& operator=(Tensor&& other) = default;
|
||||
|
||||
virtual ~Tensor() = default;
|
||||
|
||||
bool operator==(const Tensor& other) const;
|
||||
|
||||
void SetMetadata(absl::string_view key, int value) { metadata[key] = value; }
|
||||
|
||||
std::optional<int> GetMetadata(absl::string_view key) const {
|
||||
if (metadata.contains(key)) {
|
||||
return metadata.at(key);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Read weights from file.
|
||||
template <xnn_datatype xnn_datatype_ = xnn_datatype_fp32>
|
||||
static absl::StatusOr<std::shared_ptr<Tensor>> FromFile(
|
||||
absl::string_view file_path, DimsType dims, bool use_mmap = true) {
|
||||
auto result = std::make_shared<Tensor>(std::move(dims), xnn_datatype_);
|
||||
|
||||
MP_RETURN_IF_ERROR(
|
||||
result->LoadFromFile(file_path, use_mmap, /*exact_match=*/true));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
virtual absl::Status DefineAsExternal(xnn_subgraph& subgraph, uint32_t flags);
|
||||
absl::Status DefineAsInput(xnn_subgraph& subgraph);
|
||||
absl::Status DefineAsOutput(xnn_subgraph& subgraph);
|
||||
absl::Status DefineAsIntermediateTensor(xnn_subgraph& subgraph);
|
||||
virtual absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags);
|
||||
absl::Status DefineWeight(xnn_subgraph& subgraph);
|
||||
absl::Status DefineRope(xnn_subgraph& subgraph);
|
||||
|
||||
absl::Status LoadFromBuffer(const void* buffer);
|
||||
absl::Status LoadFromVec(const std::vector<float>& data,
|
||||
bool exact_match = true);
|
||||
absl::Status LoadFromVec(std::vector<float>&& data, bool exact_match = true);
|
||||
absl::Status LoadFromFile(absl::string_view file_path) {
|
||||
return LoadFromFile(file_path, true, true);
|
||||
}
|
||||
virtual absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap,
|
||||
bool exact_match);
|
||||
|
||||
absl::Status DumpToBuffer(void* buffer);
|
||||
absl::Status DumpToVec(std::vector<float>& out_data, bool exact_match = true);
|
||||
virtual absl::Status DumpToFile(absl::string_view file_path);
|
||||
|
||||
// If ith offset is 0, view's ith dim equals to original ith dim, otherwise 1.
|
||||
std::shared_ptr<Tensor> Slice(DimsType offset);
|
||||
// Slice along the `index`th dimension, offset at this dimension.
|
||||
std::shared_ptr<Tensor> Slice(size_t index, size_t offset);
|
||||
|
||||
// Point the underline data to the borrowed tensor's data.
|
||||
Tensor& Borrow(std::shared_ptr<Tensor>, size_t element_offset = 0);
|
||||
std::shared_ptr<Tensor> View();
|
||||
virtual std::shared_ptr<Tensor> View(DimsType as_dims,
|
||||
size_t dim_scale_if_any = 0);
|
||||
|
||||
Tensor& MarkOutput() {
|
||||
AllocateBufferIfNeeded();
|
||||
is_output_tensor = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
virtual void* Data();
|
||||
const void* Data() const;
|
||||
|
||||
template <typename T>
|
||||
T* DataAs() {
|
||||
DCHECK_EQ(ElementSize(), sizeof(T));
|
||||
return static_cast<T*>(Data());
|
||||
}
|
||||
template <typename T>
|
||||
const T* DataAs() const {
|
||||
return static_cast<const T*>(Data());
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<Tensor> Transpose();
|
||||
|
||||
virtual absl::StatusOr<std::shared_ptr<Tensor>> ConvertToF32();
|
||||
|
||||
DimsType dims;
|
||||
size_t num_elements = 0;
|
||||
xnn_datatype datatype = xnn_datatype_invalid;
|
||||
uint32_t tensor_id = XNN_INVALID_VALUE_ID;
|
||||
|
||||
// shared_ptr to make TensorMetadata copyable.
|
||||
std::shared_ptr<char> flat_data;
|
||||
|
||||
protected:
|
||||
friend class XnnGraphBuilder;
|
||||
friend class XnnGraph;
|
||||
|
||||
// Actually allocate buffer unless necessary.
|
||||
virtual void AllocateBufferIfNeeded();
|
||||
|
||||
virtual size_t ElementSize() const { return 4; }
|
||||
|
||||
bool is_output_tensor = false;
|
||||
|
||||
absl::flat_hash_map<std::string, int> metadata;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Tensor& tensor);
|
||||
|
||||
// Channelwise Quantized.
|
||||
struct QCTensor : public Tensor {
|
||||
explicit QCTensor(DimsType in_dims, size_t dim_scale_if_any)
|
||||
: Tensor(std::move(in_dims)), dim_scale(dim_scale_if_any) {
|
||||
datatype = xnn_datatype_qcint8;
|
||||
CHECK_LT(dim_scale, 4);
|
||||
}
|
||||
|
||||
void AllocateBufferIfNeeded() override;
|
||||
size_t ElementSize() const override { return 1; }
|
||||
|
||||
virtual absl::Status LoadFromFile(absl::string_view quantized_weight_filename,
|
||||
absl::string_view scale_filename,
|
||||
bool use_mmap, bool exact_match);
|
||||
// Append kQuantizedScaleSuffix to use as scale filename.
|
||||
absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap,
|
||||
bool exact_match) override {
|
||||
return LoadFromFile(file_path,
|
||||
absl::StrCat(file_path, kQuantizedScaleSuffix),
|
||||
use_mmap, exact_match);
|
||||
}
|
||||
|
||||
absl::Status DumpToFile(absl::string_view file_path) override;
|
||||
|
||||
absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags) override;
|
||||
|
||||
std::shared_ptr<Tensor> Transpose() override;
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Tensor>> ConvertToF32() override;
|
||||
|
||||
std::shared_ptr<Tensor> View(DimsType as_dims,
|
||||
size_t dim_scale_if_any) override;
|
||||
|
||||
std::shared_ptr<float> scale_data;
|
||||
// Index of the dimension to scale.
|
||||
size_t dim_scale;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const QCTensor& tensor);
|
||||
|
||||
absl::Status FillXnnRoPEWeights(Tensor& out_seg_pos);
|
||||
|
||||
} // namespace xnn_utils
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_
|
Loading…
Reference in New Issue
Block a user