Internal Change

PiperOrigin-RevId: 547265380
This commit is contained in:
MediaPipe Team 2023-07-11 12:32:14 -07:00 committed by Copybara-Service
parent e4ec4d2526
commit 4788fddde9
11 changed files with 2978 additions and 0 deletions

View File

@ -0,0 +1 @@
# Utilities needed to interacte with XNNPACK.

View File

@ -0,0 +1,887 @@
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/source_location.h"
#include "file/base/helpers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
#include "third_party/XNNPACK/include/xnnpack.h"
#include "util/gtl/stl_logging.h"
namespace mediapipe {
namespace xnn_utils {
namespace {
// XNNPACK supports broadcasting, this function inferences the output shape
// based on input tensor shapes.
std::vector<size_t> OutDimsForElementwiseOp(const Tensor& lhs,
const Tensor& rhs) {
DCHECK(!lhs.dims.empty());
DCHECK(!rhs.dims.empty());
std::vector<size_t> lhs_dims_rev(lhs.dims.rbegin(), lhs.dims.rend());
std::vector<size_t> rhs_dims_rev(rhs.dims.rbegin(), rhs.dims.rend());
DCHECK([&]() -> bool {
for (size_t i = 0; i < std::min(lhs_dims_rev.size(), rhs_dims_rev.size());
++i) {
if ((lhs_dims_rev[i] != rhs_dims_rev[i]) && (lhs_dims_rev[i] != 1) &&
(rhs_dims_rev[i] != 1)) {
return false;
}
}
return true;
}()) << "lhs "
<< lhs.dims << " rhs " << rhs.dims;
std::vector<size_t> out_dims(
std::max(lhs_dims_rev.size(), rhs_dims_rev.size()));
for (int i = 0; i < out_dims.size(); ++i) {
if (lhs_dims_rev.size() <= i) {
out_dims[i] = rhs_dims_rev[i];
} else if (rhs_dims_rev.size() <= i) {
out_dims[i] = lhs_dims_rev[i];
} else {
out_dims[i] = lhs_dims_rev[i] == 1 ? rhs_dims_rev[i] : lhs_dims_rev[i];
}
}
return std::vector<size_t>(out_dims.rbegin(), out_dims.rend());
}
// If out_id is invalid, we need to allocate tensor for intermediate result.
// Otherwise, set out_id in out_metadata.
absl::Status MaybeAllocateIntermediateTensor(xnn_subgraph_t subgraph,
uint32_t out_id,
Tensor& out_metadata) {
RET_CHECK_GT(out_metadata.dims.size(), 0);
if (out_id == XNN_INVALID_VALUE_ID) {
// The output is intermediate, thus allocate tensor.
MP_RETURN_IF_ERROR(out_metadata.DefineAsIntermediateTensor(*subgraph));
} else {
out_metadata.tensor_id = out_id;
}
return absl::OkStatus();
}
absl::Status MaybeAllocateIntermediateTensor(xnn_subgraph_t subgraph,
Tensor& out_metadata) {
return MaybeAllocateIntermediateTensor(subgraph, out_metadata.tensor_id,
out_metadata);
}
absl::Status AllocateIntermediateTensor(xnn_subgraph_t subgraph,
Tensor& out_metadata) {
return MaybeAllocateIntermediateTensor(subgraph, XNN_INVALID_VALUE_ID,
out_metadata);
}
// 1.0/jax.nn.softplus(0.0) = 1.442695041
// scale = softplus(w) * 1.442695041 / np.sqrt(query.shape[-1])
void SoftPlus(size_t cnt, const std::vector<size_t>& query_dims, float* weight,
float* scale) {
constexpr double r_softplus_0 = 1.442695041;
// softplus(x) = np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)
// scale = softplus(per_dim_scale) / (sqrt(input.dims[-1]) * softplus(0))
const double r_softplus_0_over_sqrt_d =
r_softplus_0 / std::sqrt(query_dims.back());
for (int i = 0; i < cnt; ++i) {
scale[i] = log1p(exp(-abs(weight[i]))) + fmax(weight[i], 0.0f);
scale[i] *= r_softplus_0_over_sqrt_d;
}
}
} // namespace
absl::StatusOr<std::unique_ptr<XnnGraph>> XnnGraphBuilder::Build(
std::unique_ptr<RuntimeConfigs> runtime_configs) {
if (!runtime_configs) {
runtime_configs = std::make_unique<RuntimeConfigs>();
runtime_configs->xnn_num_threads = 1;
runtime_configs->xnn_profile = false;
}
VLOG(2) << "XnnGraphBuilder::Build() building...";
auto build_begin = absl::Now();
RET_CHECK_EQ(xnn_status_success, xnn_initialize(nullptr));
absl::flat_hash_set<std::shared_ptr<Tensor>> output_tensors;
{
uint32_t cnt = input_tensors_.size();
for (auto& t : interm_tensors_) {
if (t->is_output_tensor) {
RET_CHECK_EQ(t->tensor_id, XNN_INVALID_VALUE_ID);
t->tensor_id = cnt++;
output_tensors.insert(t);
}
}
for (auto& t : output_tensors) {
interm_tensors_.erase(t);
}
for (auto& t : rope_weigths_) {
interm_tensors_.erase(t);
t->tensor_id = cnt++;
}
}
xnn_subgraph_t subgraph_ptr = nullptr;
RET_CHECK_EQ(xnn_status_success,
xnn_create_subgraph(
/*external_value_ids=*/input_tensors_.size() +
output_tensors.size() + rope_weigths_.size(),
/*flags=*/0, &subgraph_ptr));
RET_CHECK_NE(subgraph_ptr, nullptr);
XnnSubgraphPtr subgraph{subgraph_ptr, xnn_delete_subgraph};
for (auto& input : input_tensors_) {
MP_RETURN_IF_ERROR(input->DefineAsInput(*subgraph));
}
for (auto& output : output_tensors) {
MP_RETURN_IF_ERROR(output->DefineAsOutput(*subgraph));
}
{
for (auto& t : rope_weigths_) {
MP_RETURN_IF_ERROR(t->DefineRope(*subgraph));
}
}
for (auto& [loc, step] : build_steps_) {
if (auto s = step(subgraph.get()); !s.ok()) {
s.AddSourceLocation(loc);
return s;
}
}
XnnGraph result(std::move(subgraph), std::move(runtime_configs));
result.input_tensors_ = std::move(input_tensors_);
result.output_tensors_ = std::move(output_tensors);
result.interm_tensors_ = std::move(interm_tensors_);
VLOG(2) << "XnnGraphBuilder::Build() creating runtime...";
auto create_begin = absl::Now();
MP_RETURN_IF_ERROR(result.CreateRuntime());
VLOG(2) << "XnnGraphBuilder::Build() setting up runtime...";
auto setup_begin = absl::Now();
MP_RETURN_IF_ERROR(result.SetupRuntime());
auto end = absl::Now();
VLOG(2) << "XnnGraphBuilder::Build() done build, Total " << end - build_begin
<< ", create runtime " << setup_begin - create_begin
<< ", setup runtime " << end - setup_begin;
return std::make_unique<XnnGraph>(std::move(result));
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::NewInput(
Tensor::DimsType dims, absl::SourceLocation loc) {
auto t = std::make_shared<Tensor>(std::move(dims), data_type_);
t->AllocateBufferIfNeeded();
t->tensor_id = input_tensors_.size();
input_tensors_.insert(t);
return t;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::NewWeight(
absl::string_view file_path, Tensor::DimsType dims,
absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto t, NewWeight(std::move(dims)));
MP_RETURN_IF_ERROR(t->LoadFromFile(file_path));
return t;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::NewWeight(
Tensor::DimsType dims, absl::SourceLocation loc) {
auto t = std::make_shared<Tensor>(std::move(dims), data_type_);
NewWeight(t, loc);
return t;
}
void XnnGraphBuilder::NewWeight(std::shared_ptr<Tensor> t,
absl::SourceLocation loc) {
build_steps_.push_back(
{loc, [this, t](xnn_subgraph_t subgraph) -> absl::Status {
if (interm_tensors_.contains(t)) {
MP_RETURN_IF_ERROR(t->DefineWeight(*subgraph));
}
return absl::OkStatus();
}});
interm_tensors_.insert(t);
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::IntermediateTensor(
Tensor::DimsType dims, absl::SourceLocation loc) {
auto t = std::make_shared<Tensor>(std::move(dims), data_type_);
build_steps_.push_back(
{loc, [this, t](xnn_subgraph_t subgraph) -> absl::Status {
// Could be moved to output tensors, thus need check.
if (interm_tensors_.contains(t)) {
return AllocateIntermediateTensor(subgraph, *t);
}
return absl::OkStatus();
}});
interm_tensors_.insert(t);
return t;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Reshape(
std::shared_ptr<Tensor> input, Tensor::DimsType new_dims,
absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(new_dims)));
RET_CHECK_EQ(input->num_elements, output->num_elements)
<< "otherwise reshape does not make sense.";
build_steps_.push_back(
{loc, [this, input, output](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
subgraph, output->tensor_id, *output));
RET_CHECK_EQ(xnn_status_success,
xnn_define_static_reshape(
subgraph, output->dims.size(), output->dims.data(),
input->tensor_id, output->tensor_id, /*flags=*/0));
return absl::OkStatus();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::FullConn(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
std::shared_ptr<Tensor> bias, FullConnParams params,
absl::SourceLocation loc) {
const auto& input_dim = input->dims;
const auto& weight_dim = weight->dims;
DCHECK_GT(input_dim.size(), 1);
DCHECK_GE(weight_dim.size(), 2);
if (weight_dim.size() == 3) {
RET_CHECK_EQ(weight_dim[0], 1);
} else if (weight_dim.size() == 4) {
RET_CHECK_EQ(weight_dim[0], 1);
RET_CHECK_EQ(weight_dim[1], 1);
}
if (bias) {
RET_CHECK_LE(bias->dims.size(), 1);
}
Tensor::DimsType out_dims = input_dim;
// Not considering reshape 2D
if (params.transpose) {
RET_CHECK_EQ(weight_dim.size(), 2) << "otherwise change following line";
RET_CHECK_EQ(input_dim.back(), *(weight_dim.end() - 2));
out_dims.back() = weight_dim.back();
} else {
RET_CHECK_EQ(input_dim.back(), weight_dim.back());
out_dims.pop_back();
for (size_t i = 0; i < weight_dim.size() - 1; ++i) {
// NHD . BTD -> NHBT
out_dims.push_back(weight_dim[i]);
}
}
ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(out_dims)));
build_steps_.push_back(
{loc,
[this, input, weight, bias, params,
output](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
subgraph, output->tensor_id, *output));
RET_CHECK_EQ(
xnn_status_success,
xnn_define_fully_connected(
subgraph, params.out_min, params.out_max, input->tensor_id,
weight->tensor_id,
bias ? bias->tensor_id : XNN_INVALID_VALUE_ID,
output->tensor_id,
/*flags=*/params.transpose ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0));
return absl::OkStatus();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Permute(
std::shared_ptr<Tensor> input, Tensor::DimsType permute,
absl::SourceLocation loc) {
RET_CHECK_EQ(input->dims.size(), permute.size());
const auto& old_dims = input->dims;
std::vector<size_t> new_dims;
for (size_t i = 0; i < permute.size(); ++i) {
new_dims.push_back(old_dims[permute[i]]);
}
ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(new_dims)));
build_steps_.push_back(
{loc,
[this, permute, input, output](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
RET_CHECK_EQ(xnn_status_success,
xnn_define_static_transpose(
subgraph, permute.size(), permute.data(),
input->tensor_id, output->tensor_id, /*flags=*/0));
return absl::OkStatus();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Square(
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
build_steps_.push_back(
{loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
subgraph, output->tensor_id, *output));
RET_CHECK_EQ(
xnn_status_success,
xnn_define_square(subgraph, input->tensor_id, output->tensor_id,
/*flags=*/0));
return absl::Status();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Softmax(
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
build_steps_.push_back(
{loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
subgraph, output->tensor_id, *output));
RET_CHECK_EQ(
xnn_status_success,
xnn_define_softmax(subgraph, input->tensor_id, output->tensor_id,
/*flags=*/0));
return absl::Status();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::SquareRoot(
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
build_steps_.push_back(
{loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
subgraph, output->tensor_id, *output));
RET_CHECK_EQ(xnn_status_success,
xnn_define_square_root(subgraph, input->tensor_id,
output->tensor_id,
/*flags=*/0));
return absl::Status();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::AvgLastDim(
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto before_reshape,
IntermediateTensor(Tensor::DimsType{input->dims.begin(),
input->dims.end() - 1}));
build_steps_.push_back(
{loc,
[this, input, before_reshape](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
subgraph, before_reshape->tensor_id, *before_reshape));
size_t reduction_axis = input->dims.size() - 1;
RET_CHECK_EQ(
xnn_status_success,
xnn_define_static_mean(subgraph, 1, &reduction_axis,
input->tensor_id, before_reshape->tensor_id,
/*flags=*/0));
return absl::OkStatus();
}});
Tensor::DimsType new_dims = input->dims;
new_dims.back() = 1;
return Reshape(before_reshape, std::move(new_dims));
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Rms(
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto sqr_out, Square(input, loc));
ASSIGN_OR_RETURN(auto mean_out, AvgLastDim(sqr_out, loc));
return SquareRoot(mean_out, loc);
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::RmsNorm(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> scale,
absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto rms_out, Rms(input));
ASSIGN_OR_RETURN(auto clamped_rms, Clamp(rms_out, {.out_min = 1e-6}));
// div_out = input / rms
ASSIGN_OR_RETURN(auto div_out, ElementDiv(input, clamped_rms));
// div_out * (1 + scale) = div_out + div_out * scale
ASSIGN_OR_RETURN(auto normed_div_out, ElementMul(div_out, scale));
return ElementAdd(div_out, normed_div_out);
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementAdd(
std::shared_ptr<Tensor> lhs, float rhs, ClampParams params,
absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1}));
MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector<float>({rhs})));
return ElementAdd(lhs, rhs_tensor, params, loc);
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementAdd(
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
ClampParams params, absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto output,
IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs)));
build_steps_.push_back(
{loc,
[this, lhs, rhs, output,
params](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
RET_CHECK_EQ(xnn_status_success,
xnn_define_add2(subgraph, params.out_min, params.out_max,
lhs->tensor_id, rhs->tensor_id,
output->tensor_id, /*flags=*/0));
return absl::OkStatus();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementMul(
std::shared_ptr<Tensor> lhs, float rhs, ClampParams params,
absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1}));
MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector<float>({rhs})));
return ElementMul(lhs, rhs_tensor, params, loc);
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementMul(
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
ClampParams params, absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto output,
IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs)));
build_steps_.push_back(
{loc,
[this, lhs, rhs, output,
params](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
RET_CHECK_EQ(
xnn_status_success,
xnn_define_multiply2(subgraph, params.out_min, params.out_max,
lhs->tensor_id, rhs->tensor_id,
output->tensor_id, /*flags=*/0));
return absl::OkStatus();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementDiv(
std::shared_ptr<Tensor> lhs, float rhs, ClampParams params,
absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1}));
MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector<float>({rhs})));
return ElementDiv(lhs, rhs_tensor, params, loc);
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementDiv(
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
ClampParams params, absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto output,
IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs)));
build_steps_.push_back(
{loc,
[this, lhs, rhs, output,
params](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
RET_CHECK_EQ(
xnn_status_success,
xnn_define_divide(subgraph, params.out_min, params.out_max,
lhs->tensor_id, rhs->tensor_id,
output->tensor_id, /*flags=*/0));
return absl::OkStatus();
}});
return output;
}
// TODO: write an op?
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::PerDimScale(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> per_dim_scale,
absl::SourceLocation loc) {
// input: B T N H
// 1/softplus(0) = 1.442695041
// scale = softplus(w) * 1.442695041 / np.sqrt(query.shape[-1])
// query = query * scale
const auto& input_dim = input->dims;
DCHECK_GE(input_dim.size(), 1);
const size_t H = input_dim.back();
if (!per_dim_scale_cache_.contains(H) ||
!per_dim_scale_cache_[H].contains(per_dim_scale.get())) {
ASSIGN_OR_RETURN(auto cached_pds, NewWeight(per_dim_scale->dims));
auto* pds_in = static_cast<float*>(per_dim_scale->Data());
std::vector<float> pds_scaled(per_dim_scale->num_elements);
SoftPlus(per_dim_scale->num_elements, input_dim, pds_in, pds_scaled.data());
MP_RETURN_IF_ERROR(cached_pds->LoadFromVec(std::move(pds_scaled)));
per_dim_scale_cache_[H][per_dim_scale.get()] = cached_pds;
}
return ElementMul(input, per_dim_scale_cache_[H][per_dim_scale.get()]);
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Rope(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> segment_pos,
absl::SourceLocation loc) {
// TODO: seg_pos should not be weight.
rope_weigths_.insert(segment_pos);
const auto& input_dim = input->dims;
const auto& segment_pos_dim = segment_pos->dims;
// B T N H
RET_CHECK_EQ(input_dim.size(), 4) << "xnn requirement";
// S H
RET_CHECK_EQ(segment_pos_dim.size(), 2) << "xnn requirement";
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input_dim));
const auto input_seq_size = input_dim[1];
RET_CHECK_LE(input_seq_size, segment_pos_dim[0]);
const auto head_dim_H = input_dim[3];
RET_CHECK_EQ(head_dim_H, segment_pos_dim[1]);
build_steps_.push_back(
{loc,
[this, input, output, segment_pos,
input_seq_size](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
RET_CHECK_EQ(
xnn_status_success,
xnn_define_rope(subgraph, input_seq_size, input->tensor_id,
segment_pos->tensor_id, output->tensor_id,
/*flags=*/0));
return absl::OkStatus();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::BatchMatMul(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
FullConnParams params, absl::SourceLocation loc) {
const auto& lhs_dim = input->dims;
const auto& rhs_dim = weight->dims;
// [B, N, T, H] . [B, N, S, H], N == 12, B == 1
DCHECK_EQ(lhs_dim.size(), 4);
DCHECK_EQ(rhs_dim.size(), 4);
DCHECK_EQ(lhs_dim.back(), rhs_dim.back());
DCHECK_EQ(lhs_dim.back(), rhs_dim.back());
constexpr size_t num_slices = 12;
DCHECK_EQ(lhs_dim[1], num_slices);
DCHECK_EQ(rhs_dim[1], num_slices);
const size_t S = rhs_dim[2];
const size_t T = lhs_dim[2];
const size_t batch_size = lhs_dim[0] * lhs_dim[1];
DCHECK_EQ(batch_size, rhs_dim[0] * rhs_dim[1]);
DCHECK_EQ(batch_size, 12);
ASSIGN_OR_RETURN(auto output, IntermediateTensor({1, 12, T, S}));
build_steps_.push_back(
{loc, [input, output, weight](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
RET_CHECK_EQ(xnn_status_success,
xnn_define_batch_matrix_multiply(
subgraph, input->tensor_id, weight->tensor_id,
output->tensor_id, /*flags=*/0));
return absl::OkStatus();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Tanh(
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
build_steps_.push_back(
{loc, [this, input, output](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
RET_CHECK_EQ(xnn_status_success,
xnn_define_tanh(subgraph, input->tensor_id,
output->tensor_id, /*flags=*/0));
return absl::OkStatus();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::CapTanh(
std::shared_ptr<Tensor> input, float cap, absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto div, ElementDiv(input, cap));
ASSIGN_OR_RETURN(auto tanh, Tanh(div));
return ElementMul(tanh, cap);
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::DotAttention(
std::shared_ptr<Tensor> query_proj, std::shared_ptr<Tensor> key_proj,
std::shared_ptr<Tensor> value_proj, std::shared_ptr<Tensor> atten_mask,
std::shared_ptr<Tensor> per_dim_scale, absl::SourceLocation loc) {
// BTNH
ASSIGN_OR_RETURN(auto query_after_scale,
PerDimScale(query_proj, per_dim_scale));
// Dot similarity
// BTNH -> BNTH
ASSIGN_OR_RETURN(auto query_permuted,
Permute(query_after_scale, {0, 2, 1, 3}));
// BSNH -> BNSH
ASSIGN_OR_RETURN(auto key_permuted, Permute(key_proj, {0, 2, 1, 3}));
// einsum(BNTH.BNSH -> BNTS)
ASSIGN_OR_RETURN(auto logits, BatchMatMul(query_permuted, key_permuted));
// Cap, mask
ASSIGN_OR_RETURN(auto cap_logits, CapTanh(logits, 50));
ASSIGN_OR_RETURN(auto padded_logits, ElementAdd(atten_mask, cap_logits));
ASSIGN_OR_RETURN(auto probs, Softmax(padded_logits));
ASSIGN_OR_RETURN(auto value_permuted, Permute(value_proj, {0, 2, 3, 1}));
// Outcome
// BNTS.BNHS -> BNTH
ASSIGN_OR_RETURN(auto outcome_before_permute,
BatchMatMul(probs, value_permuted));
// [B, N, T, H] -> BTNH
return Permute(outcome_before_permute, {0, 2, 1, 3});
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::SelfAttentionProj(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
absl::SourceLocation loc) {
const auto& input_dim = input->dims;
const auto& weight_dim = weight->dims;
size_t N = 0, H = 0;
RET_CHECK_EQ(input_dim.size(), 3) << "BTD";
std::optional<int> reshaped_N =
weight->GetMetadata(kKeySelfAttentionReshapedWeight);
RET_CHECK(reshaped_N && *reshaped_N)
<< "We rely on " << kKeySelfAttentionReshapedWeight << " to get N";
RET_CHECK_EQ(weight_dim.size(), 2) << "NH,D";
N = *reshaped_N;
H = weight_dim[0] / N;
// out: B,T,NH
ASSIGN_OR_RETURN(auto proj, MatMul(input, weight));
// B,T,NH -> B,T,N,H
return Reshape(proj, {input_dim[0], input_dim[1], N, H});
}
absl::Status XnnGraph::CreateRuntime() {
RET_CHECK_EQ(runtime_.get(), nullptr);
xnn_runtime_t runtime_ptr = nullptr;
uint32_t flags = 0;
if (runtime_configs_->xnn_profile) {
flags |= XNN_FLAG_BASIC_PROFILING;
if (!runtime_configs_->xnn_profile_csv.empty()) {
MP_RETURN_IF_ERROR(file::SetContents(runtime_configs_->xnn_profile_csv,
"node_id; time(us); op_name\n",
file::Defaults()));
}
}
pthreadpool_t threadpool =
pthreadpool_create(runtime_configs_->xnn_num_threads);
threadpool_ = XnnThreadpoolPtr{threadpool, pthreadpool_destroy};
RET_CHECK_EQ(xnn_status_success,
xnn_create_runtime_v2(owned_subgraph_.get(), threadpool, flags,
&runtime_ptr));
RET_CHECK_NE(runtime_ptr, nullptr);
runtime_ = XnnRuntimePtr{runtime_ptr, xnn_delete_runtime};
return absl::OkStatus();
}
absl::Status XnnGraph::SetupRuntime() {
{
VLOG(3) << "input size " << input_tensors_.size();
VLOG(3) << "output size " << output_tensors_.size();
VLOG(3) << "rope size " << rope_weigths_.size();
externals_.clear();
// Init external
for (const auto& input : input_tensors_) {
VLOG(3) << "input id " << input->tensor_id;
externals_.push_back(xnn_external_value{input->tensor_id, input->Data()});
}
for (const auto& output : output_tensors_) {
VLOG(3) << "output id " << output->tensor_id;
externals_.push_back(
xnn_external_value{output->tensor_id, output->Data()});
}
for (const auto& t : rope_weigths_) {
VLOG(3) << "rope id " << t->tensor_id;
}
}
RET_CHECK_EQ(
xnn_status_success,
xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data()));
return absl::OkStatus();
}
absl::Status XnnGraph::Run() {
RET_CHECK(runtime_);
RET_CHECK_EQ(xnn_status_success, xnn_invoke_runtime(runtime_.get()));
if (runtime_configs_->xnn_profile) {
size_t required_size = 0;
// xnn_get_runtime_profiling_info is called twice. The first time it sets
// required_size to the required size of the buffer to store the result and
// returns xnn_status_out_of_memory. The second time it writes the result to
// the buffer provided that the buffer is large enough and returns
// xnn_status_success.
xnn_status status = xnn_get_runtime_profiling_info(
runtime_.get(), xnn_profile_info_operator_name, /*param_value_size*/ 0,
/*param_value*/ nullptr, &required_size);
std::vector<char> operator_names;
if (status == xnn_status_out_of_memory) {
operator_names.resize(required_size);
status = xnn_get_runtime_profiling_info(
runtime_.get(), xnn_profile_info_operator_name, operator_names.size(),
operator_names.data(), &required_size);
}
RET_CHECK_EQ(status, xnn_status_success);
size_t num_operators;
status = xnn_get_runtime_profiling_info(
runtime_.get(), xnn_profile_info_num_operators, sizeof(num_operators),
&num_operators, &required_size);
RET_CHECK_EQ(status, xnn_status_success);
status = xnn_get_runtime_profiling_info(
runtime_.get(), xnn_profile_info_operator_timing,
/*param_value_size*/ 0,
/*param_value*/ nullptr, &required_size);
std::vector<uint64_t> operator_timings;
if (status == xnn_status_out_of_memory) {
operator_timings.resize(required_size / sizeof(uint64_t));
status = xnn_get_runtime_profiling_info(
runtime_.get(), xnn_profile_info_operator_timing,
operator_timings.size() * sizeof(uint64_t), operator_timings.data(),
&required_size);
}
RET_CHECK_EQ(status, xnn_status_success);
const char* operator_name = nullptr;
size_t name_len = 0;
std::stringstream ss;
for (size_t node_index = 0; node_index < num_operators; ++node_index) {
operator_name = &operator_names[name_len];
name_len += strlen(operator_name) + 1;
VLOG(2) << "XnnGraphBuilder::Profile() node_index: " << node_index
<< ", time: " << operator_timings[node_index] << " us, "
<< operator_name << "\n";
if (!runtime_configs_->xnn_profile_csv.empty()) {
// Use ';' instead of ',' because operator_name contains comma.
ss << node_index << "; " << operator_timings[node_index] << "; "
<< operator_name << "\n";
}
}
if (!runtime_configs_->xnn_profile_csv.empty()) {
MP_RETURN_IF_ERROR(file::AppendStringToFile(
runtime_configs_->xnn_profile_csv, ss.str(), file::Defaults()));
}
}
return absl::OkStatus();
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Clamp(
std::shared_ptr<Tensor> input, ClampParams params,
absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
build_steps_.push_back(
{loc,
[this, input, output, params](xnn_subgraph_t subgraph) -> absl::Status {
MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
RET_CHECK_EQ(xnn_status_success,
xnn_define_clamp(subgraph, params.out_min, params.out_max,
input->tensor_id, output->tensor_id,
/*flags=*/0));
return absl::OkStatus();
}});
return output;
}
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Gelu(
std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
// x^2
ASSIGN_OR_RETURN(auto sqr_out, Square(input));
// 0.044715 * x^2
ASSIGN_OR_RETURN(auto sqr_4471, ElementMul(sqr_out, 0.044715));
// 1 + 0.044715 * x^2
ASSIGN_OR_RETURN(auto sqr_4471_1, ElementAdd(sqr_4471, 1.0f));
// x + 0.044715 * x^3
ASSIGN_OR_RETURN(auto x_cube_4471, ElementMul(sqr_4471_1, input));
constexpr float sqrt_2_over_pi = 0.7978845608;
ASSIGN_OR_RETURN(auto sqrt_2_over_pi_x_cube_4471,
ElementMul(x_cube_4471, sqrt_2_over_pi));
// tanh(x + 0.044715 * x^3)
ASSIGN_OR_RETURN(auto tanh_x_cube_4471, Tanh(sqrt_2_over_pi_x_cube_4471));
// 1 + tanh(x + 0.044715 * x^3)
ASSIGN_OR_RETURN(auto tanh_x_cube_4471_1, ElementAdd(tanh_x_cube_4471, 1.0f));
// 0.5 * (1 + [tanh(x + 0.044715 * x^3)])
ASSIGN_OR_RETURN(auto cdf, ElementMul(tanh_x_cube_4471_1, 0.5));
return ElementMul(input, cdf);
}
} // namespace xnn_utils
} // namespace mediapipe

View File

@ -0,0 +1,288 @@
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_
#include <sys/types.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/source_location.h"
#include "file/base/helpers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
#include "third_party/XNNPACK/include/xnnpack.h"
namespace mediapipe {
namespace xnn_utils {
using XnnSubgraphPtr =
std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)>;
using XnnRuntimePtr =
std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)>;
using XnnThreadpoolPtr =
std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)>;
struct ClampParams {
float out_min = -std::numeric_limits<float>::infinity();
float out_max = std::numeric_limits<float>::infinity();
};
struct FullConnParams : public ClampParams {
bool transpose = false;
};
struct RuntimeConfigs {
bool xnn_profile;
std::string xnn_profile_csv;
size_t xnn_num_threads;
};
class XnnGraph;
// XnnGraphBuilder is used to construct XnnGraph (through Build()). Once a
// XnnGraph is constructed, it can run for multiple times.
class XnnGraphBuilder {
public:
static constexpr absl::string_view kKeySelfAttentionReshapedWeight{
"self_attention_reshaped_weight_N"};
explicit XnnGraphBuilder(xnn_datatype data_type = xnn_datatype_fp32)
: data_type_(data_type) {}
virtual ~XnnGraphBuilder() = default;
absl::StatusOr<std::unique_ptr<XnnGraph>> Build(
std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
// New input or output tensor.
absl::StatusOr<std::shared_ptr<Tensor>> NewInput(
Tensor::DimsType dims,
absl::SourceLocation loc = absl::SourceLocation::current());
// New static weight, populate value before Build()
absl::StatusOr<std::shared_ptr<Tensor>> NewWeight(
Tensor::DimsType dims,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> NewWeight(
absl::string_view file_path, Tensor::DimsType dims,
absl::SourceLocation loc = absl::SourceLocation::current());
void NewWeight(std::shared_ptr<Tensor> t,
absl::SourceLocation loc = absl::SourceLocation::current());
// Element wise square.
absl::StatusOr<std::shared_ptr<Tensor>> Square(
std::shared_ptr<Tensor> input,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> SquareRoot(
std::shared_ptr<Tensor> input,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> Gelu(
std::shared_ptr<Tensor> input,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> Clamp(
std::shared_ptr<Tensor> input, ClampParams params,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> Tanh(
std::shared_ptr<Tensor> input,
absl::SourceLocation loc = absl::SourceLocation::current());
// logits = cap * jnp.tanh(logits / cap)
absl::StatusOr<std::shared_ptr<Tensor>> CapTanh(
std::shared_ptr<Tensor> input, float cap,
absl::SourceLocation loc = absl::SourceLocation::current());
// Average over last dimension, keep num of dims same.
absl::StatusOr<std::shared_ptr<Tensor>> AvgLastDim(
std::shared_ptr<Tensor> input,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> Rms(
std::shared_ptr<Tensor> input,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> RmsNorm(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> scale,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> Reshape(
std::shared_ptr<Tensor> input, Tensor::DimsType new_dims,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> Permute(
std::shared_ptr<Tensor> input, Tensor::DimsType permute,
absl::SourceLocation loc = absl::SourceLocation::current());
// input: [B * I]
// filter: [O * I], [I * O] if transpose
// return: [B * O]
absl::StatusOr<std::shared_ptr<Tensor>> MatMul(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
absl::SourceLocation loc = absl::SourceLocation::current()) {
return MatMul(input, weight, FullConnParams(), loc);
}
absl::StatusOr<std::shared_ptr<Tensor>> MatMul(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
FullConnParams params,
absl::SourceLocation loc = absl::SourceLocation::current()) {
return FullConn(input, weight, nullptr, params, loc);
}
absl::StatusOr<std::shared_ptr<Tensor>> BatchMatMul(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
FullConnParams params = FullConnParams(),
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> FullConn(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
std::shared_ptr<Tensor> bias,
absl::SourceLocation loc = absl::SourceLocation::current()) {
return FullConn(input, weight, bias, FullConnParams(), loc);
}
absl::StatusOr<std::shared_ptr<Tensor>> FullConn(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
std::shared_ptr<Tensor> bias, FullConnParams params,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> Softmax(
std::shared_ptr<Tensor> input,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionProj(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> ElementAdd(
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
ClampParams params = ClampParams(),
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> ElementAdd(
std::shared_ptr<Tensor> lhs, float rhs,
ClampParams params = ClampParams(),
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> ElementMul(
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
ClampParams params = ClampParams(),
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> ElementMul(
std::shared_ptr<Tensor> lhs, float rhs,
ClampParams params = ClampParams(),
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> ElementDiv(
std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
ClampParams params = ClampParams(),
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> ElementDiv(
std::shared_ptr<Tensor> lhs, float rhs,
ClampParams params = ClampParams(),
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> Rope(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> segment_pos,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> PerDimScale(
std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> per_dim_scale,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> DotAttention(
std::shared_ptr<Tensor> query_proj, std::shared_ptr<Tensor> key_proj,
std::shared_ptr<Tensor> value_proj, std::shared_ptr<Tensor> atten_mask,
std::shared_ptr<Tensor> per_dim_scale,
absl::SourceLocation loc = absl::SourceLocation::current());
protected:
absl::StatusOr<std::shared_ptr<Tensor>> IntermediateTensor(
Tensor::DimsType dims,
absl::SourceLocation loc = absl::SourceLocation::current());
const xnn_datatype data_type_;
std::vector<std::pair<absl::SourceLocation,
std::function<absl::Status(xnn_subgraph_t)>>>
build_steps_;
absl::flat_hash_set<std::shared_ptr<Tensor>> input_tensors_;
absl::flat_hash_set<std::shared_ptr<Tensor>> interm_tensors_;
// TODO: fix this.
// This is sort of bug that the weights used for rope has to be defined with
// EXTERNAL flag, but with id out of the external range.
absl::flat_hash_set<std::shared_ptr<Tensor>> rope_weigths_;
// Caches
absl::flat_hash_map<
size_t /*dim*/,
absl::flat_hash_map<const Tensor* /*scale*/, std::shared_ptr<Tensor>>>
per_dim_scale_cache_;
};
class XnnGraph {
public:
XnnGraph(XnnSubgraphPtr subgraph,
std::unique_ptr<RuntimeConfigs> runtime_configs)
: owned_subgraph_(std::move(subgraph)),
runtime_configs_(std::move(runtime_configs)) {
DCHECK(runtime_configs_);
}
XnnGraph(XnnGraph&& other) = default;
virtual ~XnnGraph() = default;
// xnn_subgraph should be created with same size.
virtual absl::Status Run();
protected:
friend class XnnGraphBuilder;
absl::Status CreateRuntime();
absl::Status SetupRuntime();
XnnSubgraphPtr owned_subgraph_;
absl::flat_hash_map<size_t, Tensor> avg_cache_;
absl::flat_hash_map<size_t, Tensor> cap_tanh_cache_;
// Runtime
std::unique_ptr<RuntimeConfigs> runtime_configs_;
XnnRuntimePtr runtime_{nullptr, xnn_delete_runtime};
std::vector<xnn_external_value> externals_;
XnnThreadpoolPtr threadpool_{nullptr, pthreadpool_destroy};
absl::flat_hash_set<std::shared_ptr<Tensor>> input_tensors_;
absl::flat_hash_set<std::shared_ptr<Tensor>> output_tensors_;
// TODO: see above
absl::flat_hash_set<std::shared_ptr<Tensor>> rope_weigths_;
absl::flat_hash_set<std::shared_ptr<Tensor>> interm_tensors_;
};
} // namespace xnn_utils
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_

View File

@ -0,0 +1,475 @@
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h"
#include <cstddef>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/text/text_generator/calculators/preprocessor_util.h"
#include "mediapipe/tasks/cc/text/text_generator/calculators/sampler_util.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
#include "util/gtl/stl_logging.h"
namespace mediapipe {
namespace xnn_utils {
namespace {
absl::StatusOr<std::shared_ptr<Tensor>> ApplyFinalProj(
std::shared_ptr<Tensor> inter_layer, const UlmWeights& weights,
XnnGraphBuilder& builder) {
return builder.FullConn(inter_layer, weights.softmax_linear,
weights.softmax_bias);
}
} // namespace
class OneTokenUlm : public Ulm {
public:
OneTokenUlm(std::unique_ptr<Ulm> full_ulm, XnnGraph&& other)
: Ulm(std::move(other)), full_ulm_(std::move(full_ulm)) {}
~OneTokenUlm() override = default;
absl::Status InitInputTokens(const std::vector<int>& input_ids) override {
prev_ids_ = input_ids;
MP_RETURN_IF_ERROR(full_ulm_->InitInputTokens(input_ids));
// prev_id.size - 1 is the output.
return full_ulm_->Run();
}
absl::Status GetNextToken(std::vector<int>* output_ids) override {
size_t decode_step = prev_ids_.size() - 1;
VLOG(2) << "Decode step " << decode_step;
if (decode_step == ulm_params_.seq_size_T - 1) {
return absl::OutOfRangeError(
absl::StrCat("Hit max sequence length ", ulm_params_.seq_size_T));
}
transformer_input_->Borrow(
full_ulm_->transformer_input_->Slice(1, decode_step));
atten_masks_->Borrow(full_ulm_->atten_masks_->Slice(0, decode_step));
MP_RETURN_IF_ERROR(segment_pos_->LoadFromBuffer(
full_ulm_->segment_pos_->Slice(0, decode_step)->Data()));
for (auto& kv_cache : kv_cache_) {
DCHECK(kv_cache.k_slice);
DCHECK(kv_cache.v_slice);
kv_cache.k_slice->Borrow(kv_cache.k_cache->Slice(1, decode_step));
kv_cache.v_slice->Borrow(kv_cache.v_cache->Slice(1, decode_step));
}
MP_RETURN_IF_ERROR(SetupRuntime());
MP_RETURN_IF_ERROR(Run());
RET_CHECK(logits_output_);
DCHECK_EQ(logits_output_->num_elements, ulm_params_.voc_size_V);
ASSIGN_OR_RETURN(*output_ids,
mediapipe::SampleNextToken(
logits_output_->DataAs<float>(),
/*batch_size=*/1,
/*vocab_size=*/ulm_params_.voc_size_V, /*top_k=*/10,
/*top_p=*/1, /*temperature=*/-1));
RET_CHECK_EQ(output_ids->size(), 1);
prev_ids_.push_back(output_ids->at(0));
return GetTokenEmbedding(
*output_ids,
pos_embedding_data_->Slice({decode_step + 1, 0})->DataAs<float>(),
full_ulm_->transformer_input_->Slice({0, decode_step + 1, 0})
->DataAs<float>());
}
private:
std::unique_ptr<Ulm> full_ulm_;
};
absl::StatusOr<std::shared_ptr<Tensor>> UlmBuilder::SelfAttentionExcludeNorm(
std::shared_ptr<Tensor> input, SelfAttentionArgs args,
const SelfAttentionWeights& sa_weights, absl::SourceLocation loc) {
// [B, 1|T, N, H]
ASSIGN_OR_RETURN(auto k_proj, SelfAttentionProj(input, sa_weights.k_weight));
ASSIGN_OR_RETURN(auto q_proj, SelfAttentionProj(input, sa_weights.q_weight));
ASSIGN_OR_RETURN(auto v_proj, SelfAttentionProj(input, sa_weights.v_weight));
ASSIGN_OR_RETURN(auto query_proj_after_rope, Rope(q_proj, args.segment_pos));
ASSIGN_OR_RETURN(auto key_proj_after_rope, Rope(k_proj, args.segment_pos));
if (args.cache) {
RET_CHECK(args.cache->k_cache);
RET_CHECK(args.cache->v_cache);
// When cache is provided, there are 2 cases:
if (*(input->dims.end() - 2) != 1) {
// Building a normal graph, which is used to initialize cache.
key_proj_after_rope->Borrow(args.cache->k_cache).MarkOutput();
v_proj->Borrow(args.cache->v_cache).MarkOutput();
} else {
// Building a one-token graph, which consumes initialized cache.
key_proj_after_rope->MarkOutput();
args.cache->k_slice = key_proj_after_rope;
v_proj->MarkOutput();
args.cache->v_slice = v_proj;
ASSIGN_OR_RETURN(key_proj_after_rope,
NewInput(args.cache->k_cache->dims));
key_proj_after_rope->Borrow(args.cache->k_cache);
ASSIGN_OR_RETURN(v_proj, NewInput(args.cache->v_cache->dims));
v_proj->Borrow(args.cache->v_cache);
}
}
// encoded, [B, 1|T, N, H]
ASSIGN_OR_RETURN(
auto kqv_merged,
DotAttention(query_proj_after_rope, key_proj_after_rope, v_proj,
args.atten_mask, sa_weights.per_dim_scale));
const size_t B = kqv_merged->dims[0];
const size_t T_or_1 = kqv_merged->dims[1];
const size_t NH = kqv_merged->num_elements / (B * T_or_1);
ASSIGN_OR_RETURN(auto outcome_reshaped, Reshape(kqv_merged, {B, T_or_1, NH}));
return MatMul(outcome_reshaped, sa_weights.post_proj_weight,
{.transpose = false});
}
absl::StatusOr<std::shared_ptr<Tensor>>
UlmBuilder::SelfAttentionIncludeResidual(std::shared_ptr<Tensor> input,
SelfAttentionArgs args,
const SelfAttentionWeights& params,
absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto pre_attention, RmsNorm(input, params.pre_norm));
ASSIGN_OR_RETURN(
auto post_attention,
SelfAttentionExcludeNorm(pre_attention, std::move(args), params));
ASSIGN_OR_RETURN(auto post_norm, RmsNorm(post_attention, params.post_norm));
return ElementAdd(input, post_norm);
}
absl::StatusOr<std::shared_ptr<Tensor>> UlmBuilder::FeedForwardExcludeResidual(
std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto first_rms_norm, RmsNorm(input, params.pre_norm));
ASSIGN_OR_RETURN(auto layer_1, FullConn(first_rms_norm, params.layer_1_weight,
params.layer_1_bias));
ASSIGN_OR_RETURN(auto layer_1_gate_before_gelu,
FullConn(first_rms_norm, params.layer_1_gate_weight,
params.layer_1_gate_bias));
ASSIGN_OR_RETURN(auto layer_1_gate, Gelu(layer_1_gate_before_gelu));
ASSIGN_OR_RETURN(auto layer_1_and_gate, ElementMul(layer_1, layer_1_gate));
if (params.opt_padding) {
// activations *= 1.0 - paddings
ASSIGN_OR_RETURN(auto tmp, ElementMul(params.opt_padding, -1.0f));
ASSIGN_OR_RETURN(tmp, ElementMul(layer_1_and_gate, tmp));
ASSIGN_OR_RETURN(layer_1_and_gate, ElementAdd(tmp, layer_1_and_gate));
}
ASSIGN_OR_RETURN(
auto layer_2,
FullConn(layer_1_and_gate, params.layer_2_weight, params.layer_2_bias));
if (params.opt_padding) {
// activations *= 1.0 - paddings
ASSIGN_OR_RETURN(auto tmp, ElementMul(params.opt_padding, -1.0f));
ASSIGN_OR_RETURN(tmp, ElementMul(layer_2, tmp));
ASSIGN_OR_RETURN(layer_2, ElementAdd(tmp, layer_2));
}
return RmsNorm(layer_2, params.post_norm);
}
absl::StatusOr<std::shared_ptr<Tensor>> UlmBuilder::FeedForwardIncludeResidual(
std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
absl::SourceLocation loc) {
ASSIGN_OR_RETURN(auto before_residual,
FeedForwardExcludeResidual(input, params));
return ElementAdd(before_residual, input);
}
absl::StatusOr<std::unique_ptr<Ulm>> Ulm::CreateUlm(
absl::string_view weights_folder, const UlmParams& ulm_params,
std::unique_ptr<RuntimeConfigs> runtime_configs) {
auto weight_loader =
std::make_unique<DefaultUlmWeightsLoader>(weights_folder, ulm_params);
return CreateUlm(std::move(weight_loader), std::move(runtime_configs));
}
absl::StatusOr<std::unique_ptr<Ulm>> Ulm::CreateOneTokenUlm(
std::unique_ptr<UlmWeightsLoader> weight_loader,
std::unique_ptr<RuntimeConfigs> runtime_configs) {
UlmBuilder builder;
// TODO: might be memory waste here, benchmark.
weight_loader->SetBuilder(builder);
ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights());
UlmParams ulm_params = weight_loader->ulm_params();
ulm_params.enable_kv_cache = true;
weight_loader->ulm_params().enable_kv_cache = true;
weight_loader->ulm_params().final_norm = false;
weight_loader->ulm_params().final_project = false;
ASSIGN_OR_RETURN(auto full_ulm, CreateUlm(std::move(weight_loader)));
ASSIGN_OR_RETURN(auto input, builder.NewInput({ulm_params.batch_size_B, 1,
ulm_params.model_dim_D}));
ASSIGN_OR_RETURN(auto atten_masks,
builder.NewInput({1, ulm_params.seq_size_T}));
ASSIGN_OR_RETURN(auto segment_pos,
builder.NewWeight({1, ulm_params.head_dim_H}));
// To allocate buffer before creating runtime.
MP_RETURN_IF_ERROR(segment_pos->LoadFromVec({}, /*exact_match=*/false));
std::vector<KVCache>& kv_cache = full_ulm->kv_cache_;
RET_CHECK_EQ(kv_cache.size(), ulm_params.num_transformer_M);
auto inter_layer = input;
for (int i = 0; i < ulm_params.num_transformer_M; ++i) {
const auto& sa = weights.sas[i];
ASSIGN_OR_RETURN(auto tmp, builder.SelfAttentionIncludeResidual(
inter_layer,
{.atten_mask = atten_masks,
.segment_pos = segment_pos,
.cache = &kv_cache[i]},
sa));
auto& ff = weights.ffs[i];
// ff.opt_padding = paddings;
ASSIGN_OR_RETURN(inter_layer, builder.FeedForwardIncludeResidual(tmp, ff));
}
std::shared_ptr<Tensor> logits_output, transformer_output, normed_output;
if (ulm_params.final_norm) {
ASSIGN_OR_RETURN(inter_layer,
builder.RmsNorm(inter_layer, weights.final_ln_scale));
normed_output = inter_layer;
normed_output->MarkOutput();
}
if (ulm_params.final_project) {
RET_CHECK(weights.softmax_linear);
ASSIGN_OR_RETURN(logits_output,
ApplyFinalProj(inter_layer, weights, builder));
logits_output->MarkOutput();
}
ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs)));
Ulm* full_ulm_p = full_ulm.get();
auto result =
std::make_unique<OneTokenUlm>(std::move(full_ulm), std::move(*graph));
{
Tensor::DimsType dims{ulm_params.seq_size_T, ulm_params.model_dim_D};
result->pos_embedding_data_ =
std::make_shared<Tensor>(std::move(dims), xnn_datatype_fp32);
result->pos_embedding_data_->Borrow(full_ulm_p->pos_embedding_data_);
}
result->transformer_input_ = input;
result->transformer_output_ = transformer_output;
result->normed_output_ = normed_output;
result->logits_output_ = logits_output;
result->segment_pos_ = segment_pos;
result->atten_masks_ = atten_masks;
if (ulm_params.use_padding) {
// result->paddings_ = paddings;
}
result->kv_cache_ = std::move(kv_cache);
result->weights_ = std::move(weights);
result->ulm_params_ = ulm_params;
return result;
}
absl::StatusOr<std::unique_ptr<Ulm>> Ulm::CreateUlm(
std::unique_ptr<UlmWeightsLoader> weight_loader,
std::unique_ptr<RuntimeConfigs> runtime_configs) {
UlmBuilder builder;
weight_loader->SetBuilder(builder);
const auto& ulm_params = weight_loader->ulm_params();
RET_CHECK_NE(ulm_params.batch_size_B, 0);
ASSIGN_OR_RETURN(auto input, builder.NewInput({ulm_params.batch_size_B,
ulm_params.seq_size_T,
ulm_params.model_dim_D}));
ASSIGN_OR_RETURN(auto atten_masks, builder.NewInput({ulm_params.seq_size_T,
ulm_params.seq_size_T}));
VLOG(1) << "atten mask id " << atten_masks->tensor_id;
ASSIGN_OR_RETURN(
auto segment_pos,
builder.NewWeight({ulm_params.seq_size_T, ulm_params.head_dim_H}));
MP_RETURN_IF_ERROR(FillXnnRoPEWeights(*segment_pos));
VLOG(1) << "segment pos id " << segment_pos->tensor_id;
std::shared_ptr<Tensor> paddings;
if (ulm_params.use_padding) {
ASSIGN_OR_RETURN(paddings, builder.NewInput({ulm_params.batch_size_B,
ulm_params.seq_size_T, 1}));
VLOG(1) << "paddings id " << paddings->tensor_id;
}
ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights());
std::vector<KVCache> kv_cache;
auto inter_layer = input;
for (int i = 0; i < ulm_params.num_transformer_M; ++i) {
const auto& sa = weights.sas[i];
KVCache* cache = nullptr;
if (ulm_params.enable_kv_cache) {
auto k_cache = std::make_shared<Tensor>(
Tensor::DimsType{ulm_params.batch_size_B, ulm_params.seq_size_T,
ulm_params.n_heads_N, ulm_params.head_dim_H});
MP_RETURN_IF_ERROR(k_cache->LoadFromVec({}, /*exact_match=*/false));
auto v_cache = std::make_shared<Tensor>(
Tensor::DimsType{ulm_params.batch_size_B, ulm_params.seq_size_T,
ulm_params.n_heads_N, ulm_params.head_dim_H});
MP_RETURN_IF_ERROR(v_cache->LoadFromVec({}, /*exact_match=*/false));
kv_cache.push_back(KVCache{.k_cache = k_cache, .v_cache = v_cache});
cache = &kv_cache.back();
}
ASSIGN_OR_RETURN(auto tmp, builder.SelfAttentionIncludeResidual(
inter_layer,
{.atten_mask = atten_masks,
.segment_pos = segment_pos,
.cache = cache},
sa));
auto& ff = weights.ffs[i];
ff.opt_padding = paddings;
ASSIGN_OR_RETURN(inter_layer, builder.FeedForwardIncludeResidual(tmp, ff));
}
std::shared_ptr<Tensor> logits_output, transformer_output, normed_output;
if (!ulm_params.final_norm && !ulm_params.final_project) {
transformer_output = inter_layer;
transformer_output->MarkOutput();
}
if (ulm_params.final_norm) {
ASSIGN_OR_RETURN(inter_layer,
builder.RmsNorm(inter_layer, weights.final_ln_scale));
normed_output = inter_layer;
normed_output->MarkOutput();
}
if (ulm_params.final_project) {
RET_CHECK(weights.softmax_linear);
ASSIGN_OR_RETURN(logits_output,
ApplyFinalProj(inter_layer, weights, builder));
logits_output->MarkOutput();
}
ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs)));
auto ulm = std::make_unique<Ulm>(std::move(*graph));
{
ASSIGN_OR_RETURN(auto pos_embedding_data,
mediapipe::PositionEmbedding(ulm_params.seq_size_T,
ulm_params.model_dim_D));
Tensor::DimsType dims{ulm_params.seq_size_T, ulm_params.model_dim_D};
ulm->pos_embedding_data_ =
std::make_shared<Tensor>(std::move(dims), xnn_datatype_fp32);
MP_RETURN_IF_ERROR(
ulm->pos_embedding_data_->LoadFromVec(pos_embedding_data));
}
ulm->transformer_input_ = input;
ulm->transformer_output_ = transformer_output;
ulm->normed_output_ = normed_output;
ulm->logits_output_ = logits_output;
ulm->segment_pos_ = segment_pos;
ulm->atten_masks_ = atten_masks;
if (ulm_params.use_padding) {
ulm->paddings_ = paddings;
}
ulm->kv_cache_ = std::move(kv_cache);
ulm->weights_ = std::move(weights);
ulm->ulm_params_ = ulm_params;
return ulm;
}
absl::Status Ulm::InitInputTokens(const std::vector<int>& input_ids) {
prev_ids_ = input_ids;
constexpr float neg_value = 0.7 * std::numeric_limits<float>::lowest();
const auto& seq_size = ulm_params_.seq_size_T;
std::vector<float> attention_array(seq_size * seq_size, neg_value);
for (int i = 0; i < seq_size; ++i) {
for (int j = 0; j < seq_size; ++j) {
if (i < input_ids.size() && j < input_ids.size()) {
attention_array[seq_size * i + j] = 0;
} else if (i >= seq_size && j <= i) {
attention_array[seq_size * i + j] = 0;
} else {
break;
}
}
}
MP_RETURN_IF_ERROR(atten_masks_->LoadFromVec(attention_array));
MP_RETURN_IF_ERROR(GetTokenEmbedding(input_ids,
pos_embedding_data_->DataAs<float>(),
transformer_input_->DataAs<float>()));
return SetupRuntime();
}
absl::Status Ulm::GetNextToken(std::vector<int>* output_ids) {
VLOG(2) << "Decode step " << prev_ids_.size() - 1;
MP_RETURN_IF_ERROR(Run());
RET_CHECK(logits_output_);
std::shared_ptr<Tensor> logits =
logits_output_->Slice({0, prev_ids_.size() - 1, 0});
DCHECK_EQ(logits->num_elements, ulm_params_.voc_size_V);
ASSIGN_OR_RETURN(*output_ids,
mediapipe::SampleNextToken(
logits->DataAs<float>(),
/*batch_size=*/1,
/*vocab_size=*/ulm_params_.voc_size_V, /*top_k=*/10,
/*top_p=*/1, /*temperature=*/-1));
RET_CHECK_EQ(output_ids->size(), 1);
prev_ids_.push_back(output_ids->at(0));
return GetTokenEmbedding(
*output_ids,
pos_embedding_data_->Slice({prev_ids_.size() - 1, 0})->DataAs<float>(),
transformer_input_->Slice({0, prev_ids_.size() - 1, 0})->DataAs<float>());
}
absl::Status Ulm::GetTokenEmbedding(const std::vector<int>& ids,
const float* pos_embedding_data,
float* embedding) {
auto token_embedding = weights_.token_embedding ? weights_.token_embedding
: weights_.softmax_linear;
RET_CHECK(token_embedding->dims[0] == ulm_params_.voc_size_V)
<< "shape must be [vocab_size, _], such that following Slice() makes "
"sense.";
for (size_t id : ids) {
memcpy(embedding, token_embedding->Slice(0, id)->Data(),
ulm_params_.model_dim_D * sizeof(float));
for (size_t i = 0; i < ulm_params_.model_dim_D; ++i) {
embedding[i] += pos_embedding_data[i];
}
pos_embedding_data += ulm_params_.model_dim_D;
embedding += ulm_params_.model_dim_D;
}
return absl::OkStatus();
}
} // namespace xnn_utils
} // namespace mediapipe

View File

@ -0,0 +1,127 @@
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_
#include <cstddef>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
namespace mediapipe {
namespace xnn_utils {
class Ulm : public XnnGraph {
public:
using UlmParams = UlmParams;
explicit Ulm(XnnGraph&& other) : XnnGraph(std::move(other)) {}
~Ulm() override = default;
// Creating ULM graph with default params. The default param corresponds to
// ULM1B 256k model.
static absl::StatusOr<std::unique_ptr<Ulm>> CreateUlm(
absl::string_view weights_folder,
const UlmParams& ulm_params =
UlmParams{
.num_transformer_M = 18,
.batch_size_B = 1,
.seq_size_T = 16,
.model_dim_D = 1536,
.hidden_dim_HD = 8 * 1536,
.head_dim_H = 128,
.n_heads_N = 12,
.voc_size_V = 256128,
},
std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
static absl::StatusOr<std::unique_ptr<Ulm>> CreateUlm(
std::unique_ptr<UlmWeightsLoader> weight_loader,
std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
// Build the graph for one-token inference.
static absl::StatusOr<std::unique_ptr<Ulm>> CreateOneTokenUlm(
std::unique_ptr<UlmWeightsLoader> weight_loader,
std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
// (Re)Initialize with input token ids. This will reset the cache, mask etc.
virtual absl::Status InitInputTokens(const std::vector<int>& input_ids);
// Get the next token id.
virtual absl::Status GetNextToken(std::vector<int>* output_ids);
protected:
friend class OneTokenUlm;
friend class UlmTest;
friend class UlmBuilder;
// Enable if enable_kv_cache
struct KVCache {
std::shared_ptr<Tensor> k_cache;
std::shared_ptr<Tensor> v_cache;
std::shared_ptr<Tensor> k_slice;
std::shared_ptr<Tensor> v_slice;
};
absl::Status GetTokenEmbedding(const std::vector<int>& ids,
const float* pos_embedding_data,
float* embedding);
UlmWeights weights_;
UlmParams ulm_params_;
std::shared_ptr<Tensor> pos_embedding_data_;
std::shared_ptr<Tensor> atten_masks_;
std::shared_ptr<Tensor> segment_pos_;
std::shared_ptr<Tensor> paddings_;
std::shared_ptr<Tensor> transformer_input_;
std::shared_ptr<Tensor> transformer_output_;
std::shared_ptr<Tensor> normed_output_;
std::shared_ptr<Tensor> logits_output_;
// Previous ids, including prompt.
std::vector<int> prev_ids_;
// If enable_kv_cache, expect a mask of [0, ... 0, 1, 0, 0...], size 1 x T.
std::shared_ptr<Tensor> decode_step_mask_;
// [1, 1, ..., 1, 0, 0...], applied on cache
std::shared_ptr<Tensor> decode_step_mask_for_cache_;
std::vector<KVCache> kv_cache_;
};
class UlmBuilder : public XnnGraphBuilder {
public:
struct SelfAttentionArgs {
std::shared_ptr<Tensor> atten_mask;
std::shared_ptr<Tensor> segment_pos;
Ulm::KVCache* cache = nullptr;
};
absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionExcludeNorm(
std::shared_ptr<Tensor> input, SelfAttentionArgs args,
const SelfAttentionWeights& sa_weights,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionIncludeResidual(
std::shared_ptr<Tensor> input, SelfAttentionArgs args,
const SelfAttentionWeights& params,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> FeedForwardExcludeResidual(
std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
absl::SourceLocation loc = absl::SourceLocation::current());
absl::StatusOr<std::shared_ptr<Tensor>> FeedForwardIncludeResidual(
std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
absl::SourceLocation loc = absl::SourceLocation::current());
};
} // namespace xnn_utils
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_

View File

@ -0,0 +1,366 @@
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h"
#include <cmath>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "file/base/filesystem.h"
#include "file/base/options.h"
#include "file/base/path.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
#include "third_party/XNNPACK/include/xnnpack.h"
namespace mediapipe {
namespace xnn_utils {
namespace {
absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefixHelper(
XnnGraphBuilder& builder, absl::string_view prefix,
const Tensor::DimsType& dims, size_t dim_scale_if_any) {
RET_CHECK(!prefix.empty() && prefix.back() != '.');
std::vector<std::string> filenames;
auto s = file::Match(absl::StrCat(prefix, "*"), &filenames, file::Defaults());
if (!s.ok()) {
LOG(WARNING) << s;
return nullptr;
} else if (filenames.empty()) {
return nullptr;
}
if (filenames.size() == 1) {
RET_CHECK_EQ(filenames[0], prefix);
return builder.NewWeight(filenames[0], dims);
}
bool is_quantized_tensor = false;
for (const auto& filename : filenames) {
if (absl::StrContains(filename, kQuantizedScaleSuffix)) {
is_quantized_tensor = true;
continue;
}
}
RET_CHECK(is_quantized_tensor)
<< "At least one of {" << filenames << "} must be quantize scale file.";
std::shared_ptr<Tensor> result;
result = std::make_shared<QCTensor>(dims, dim_scale_if_any);
MP_RETURN_IF_ERROR(result->LoadFromFile(prefix));
builder.NewWeight(result);
return result;
}
absl::Status TransposeSelfAttentionWeight(
const UlmWeightsLoader& loader, std::shared_ptr<Tensor>& original_weight,
absl::string_view cache_file_prefix) {
const auto& ulm_param = loader.ulm_params();
RET_CHECK(original_weight);
std::optional<int> from_cache =
original_weight->GetMetadata(UlmWeights::kKeyLoadedFromCache);
if (from_cache && *from_cache) {
return absl::OkStatus();
}
if (auto s = original_weight->DumpToFile(cache_file_prefix); !s.ok()) {
LOG(WARNING) << s;
} else {
MP_RETURN_IF_ERROR(original_weight->LoadFromFile(cache_file_prefix));
}
loader.builder().NewWeight(original_weight);
original_weight->SetMetadata(XnnGraphBuilder::kKeySelfAttentionReshapedWeight,
ulm_param.n_heads_N);
return absl::OkStatus();
}
} // namespace
absl::Status PrepareTokenEmbeddingDecorator::Decorate(
const UlmWeightsLoader& loader, UlmWeights& weight) {
if (weight.token_embedding) {
return absl::OkStatus();
}
const auto& ulm_params = loader.ulm_params();
absl::string_view cache_path = loader.ulm_params().weight_cache_path;
std::string token_embedding_cache_path =
cache_path.empty() ? "" : file::JoinPath(cache_path, "token_embedding.w");
// 1. try cache
if (!token_embedding_cache_path.empty()) {
auto token_embedding =
Tensor::FromFile(token_embedding_cache_path,
{ulm_params.voc_size_V, ulm_params.model_dim_D});
if (token_embedding.ok()) {
weight.token_embedding = *token_embedding;
return absl::OkStatus();
}
}
// 2. fill embedding from softmax_linear
auto& softmax_linear = *weight.softmax_linear;
RET_CHECK(softmax_linear.dims[0] == ulm_params.voc_size_V) << softmax_linear;
if (softmax_linear.datatype == xnn_datatype_fp32) {
weight.token_embedding = softmax_linear.View();
} else if (softmax_linear.datatype == xnn_datatype_qcint8) {
ASSIGN_OR_RETURN(weight.token_embedding, softmax_linear.ConvertToF32());
}
float* embedding_data = weight.token_embedding->DataAs<float>();
for (size_t i = 0; i < softmax_linear.num_elements; ++i) {
embedding_data[i] *= std::sqrt(loader.ulm_params().model_dim_D);
}
// 3. save cache
if (!token_embedding_cache_path.empty()) {
MP_RETURN_IF_ERROR(
weight.token_embedding->DumpToFile(token_embedding_cache_path));
return weight.token_embedding->LoadFromFile(token_embedding_cache_path);
}
return absl::OkStatus();
}
absl::Status TransposeSelfAttentionWeightDecorator::Decorate(
const UlmWeightsLoader& loader, UlmWeights& weight) {
absl::string_view cache_path = loader.ulm_params().weight_cache_path;
if (cache_path.empty()) {
return absl::OkStatus();
}
for (size_t i = 0; i < weight.sas.size(); ++i) {
auto& sa = weight.sas[i];
auto prefix = absl::StrCat(UlmWeightsLoader::kTransformerWeightPrefix, i,
".self_attention.");
MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight(
loader, sa.k_weight,
file::JoinPath(cache_path, absl::StrCat(prefix, "k.w"))));
MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight(
loader, sa.q_weight,
file::JoinPath(cache_path, absl::StrCat(prefix, "q.w"))));
MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight(
loader, sa.v_weight,
file::JoinPath(cache_path, absl::StrCat(prefix, "v.w"))));
}
return absl::OkStatus();
}
absl::StatusOr<std::shared_ptr<Tensor>> UlmWeightsLoader::LoadFromAbsPathPrefix(
absl::string_view prefix, const Tensor::DimsType& dims,
size_t dim_scale_if_any) const {
return LoadFromAbsPathPrefixHelper(*builder_, prefix, dims, dim_scale_if_any);
}
absl::StatusOr<std::shared_ptr<Tensor>>
UlmWeightsLoader::TryCacheThenLoadSelfAttention(
absl::string_view filename_prefix) const {
ASSIGN_OR_RETURN(
auto r,
TryCacheThenLoadWeightTranspose(
filename_prefix,
{params_.model_dim_D, params_.n_heads_N * params_.head_dim_H}, 1));
r->SetMetadata(XnnGraphBuilder::kKeySelfAttentionReshapedWeight,
params_.n_heads_N);
return r;
}
absl::StatusOr<std::shared_ptr<Tensor>>
UlmWeightsLoader::TryCacheThenLoadFeedForward(
absl::string_view filename_prefix,
std::optional<Tensor::DimsType> dims) const {
if (!dims) {
dims = {params_.model_dim_D, params_.hidden_dim_HD};
}
return TryCacheThenLoadWeightTranspose(filename_prefix, *dims, 1);
}
absl::StatusOr<std::shared_ptr<Tensor>>
UlmWeightsLoader::TryCacheThenLoadWeightTranspose(
absl::string_view filename_prefix, Tensor::DimsType original_dims,
size_t original_dim_cale) const {
if (!params_.weight_cache_path.empty()) {
auto cache_full_prefix =
file::JoinPath(params_.weight_cache_path, filename_prefix);
Tensor::DimsType cache_dim{original_dims.rbegin(), original_dims.rend()};
ASSIGN_OR_RETURN(auto r, LoadFromAbsPathPrefix(
cache_full_prefix, std::move(cache_dim),
/*dim_scale_if_any=*/1 - original_dim_cale));
if (r) {
r->SetMetadata(UlmWeights::kKeyLoadedFromCache, 1);
return r;
}
}
ASSIGN_OR_RETURN(auto r, LoadFromAbsPathPrefix(
file::JoinPath(weight_path_, filename_prefix),
std::move(original_dims),
/*dim_scale_if_any=*/original_dim_cale));
RET_CHECK(r) << file::JoinPath(weight_path_, filename_prefix);
r = r->Transpose();
builder_->NewWeight(r);
return r;
}
absl::StatusOr<FeedForwardWeights> UlmWeightsLoader::LoadFeedForward(
int layer_id) {
absl::string_view weights_folder = weight_path_;
const auto& params = params_;
auto ff_file_prefix =
absl::StrCat(kTransformerWeightPrefix, layer_id, ".ff_layer.");
auto ff_prefix = file::JoinPath(weights_folder, ff_file_prefix);
FeedForwardWeights feed_forward;
ASSIGN_OR_RETURN(
feed_forward.pre_norm,
LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "pre_layer_norm.scale"),
{params.model_dim_D}));
ASSIGN_OR_RETURN(
feed_forward.post_norm,
LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "post_layer_norm.scale"),
{params.model_dim_D}));
ASSIGN_OR_RETURN(
feed_forward.layer_1_bias,
LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer1.bias.b"),
{params.hidden_dim_HD}));
ASSIGN_OR_RETURN(feed_forward.layer_1_weight,
TryCacheThenLoadFeedForward(
absl::StrCat(ff_file_prefix, "ffn_layer1.linear.w")));
ASSIGN_OR_RETURN(
feed_forward.layer_1_gate_bias,
LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer1_gate.bias.b"),
{params.hidden_dim_HD}));
ASSIGN_OR_RETURN(feed_forward.layer_1_gate_weight,
TryCacheThenLoadFeedForward(absl::StrCat(
ff_file_prefix, "ffn_layer1_gate.linear.w")));
ASSIGN_OR_RETURN(
feed_forward.layer_2_bias,
LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer2.bias.b"),
{params.model_dim_D}, /*dim_scale_if_any=*/0));
ASSIGN_OR_RETURN(
feed_forward.layer_2_weight,
TryCacheThenLoadFeedForward(
absl::StrCat(ff_file_prefix, "ffn_layer2.linear.w"),
Tensor::DimsType{params.hidden_dim_HD, params.model_dim_D}));
return feed_forward;
}
absl::StatusOr<SelfAttentionWeights> UlmWeightsLoader::LoadSelfAttention(
int layer_id) {
absl::string_view weights_folder = weight_path_;
const auto& params = params_;
SelfAttentionWeights self_attention;
auto sa_file_prefix = absl::StrCat(kTransformerWeightPrefix, layer_id);
auto sa_prefix = file::JoinPath(weights_folder, sa_file_prefix);
ASSIGN_OR_RETURN(
self_attention.pre_norm,
LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, ".pre_layer_norm.scale"),
{params.model_dim_D}));
ASSIGN_OR_RETURN(
self_attention.post_norm,
LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, ".post_layer_norm.scale"),
{params.model_dim_D}));
absl::StrAppend(&sa_file_prefix, ".self_attention.");
ASSIGN_OR_RETURN(
self_attention.k_weight,
TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "k.w")));
ASSIGN_OR_RETURN(
self_attention.q_weight,
TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "q.w")));
ASSIGN_OR_RETURN(
self_attention.v_weight,
TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "v.w")));
sa_prefix = file::JoinPath(weights_folder, sa_file_prefix);
ASSIGN_OR_RETURN(self_attention.per_dim_scale,
LoadFromAbsPathPrefix(
absl::StrCat(sa_prefix, "per_dim_scale.per_dim_scale"),
{params.head_dim_H}));
ASSIGN_OR_RETURN(self_attention.post_proj_weight,
LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, "post.w"),
{params.model_dim_D,
params.n_heads_N * params.head_dim_H},
/*dim_scale_if_any=*/0));
return self_attention;
}
absl::StatusOr<UlmWeights> UlmWeightsLoader::LoadWeights() {
absl::string_view weights_folder = weight_path_;
const auto& params = params_;
UlmWeights result;
for (int layer_id = 0; layer_id < params.num_transformer_M; ++layer_id) {
ASSIGN_OR_RETURN(auto ff, LoadFeedForward(layer_id));
result.ffs.push_back(std::move(ff));
ASSIGN_OR_RETURN(auto sa, LoadSelfAttention(layer_id));
result.sas.push_back(std::move(sa));
}
if (params.final_norm) {
ASSIGN_OR_RETURN(result.final_ln_scale,
LoadFromAbsPathPrefix(
file::JoinPath(weights_folder, kFinalScaleFilename),
{params.model_dim_D}));
}
ASSIGN_OR_RETURN(result.softmax_bias,
LoadFromAbsPathPrefix(
file::JoinPath(weights_folder, kLogitsFfnBiasFilename),
{params.voc_size_V}));
ASSIGN_OR_RETURN(result.softmax_linear,
TryCacheThenLoadWeightTranspose(
kLogitsFfnWeightFilename,
{params.model_dim_D, params.voc_size_V}, 1));
return result;
}
BenchmarkUlmWeightsLoader::BenchmarkUlmWeightsLoader(const UlmParams& params,
xnn_datatype data_type)
: DefaultUlmWeightsLoader("", params), data_type_(data_type) {
params_.weight_cache_path.clear();
}
absl::StatusOr<std::shared_ptr<Tensor>>
BenchmarkUlmWeightsLoader::TryCacheThenLoadWeightTranspose(
absl::string_view filename_prefix, Tensor::DimsType original_dims,
size_t original_dim_cale) const {
auto result = std::make_shared<QCTensor>(
Tensor::DimsType{original_dims.rbegin(), original_dims.rend()},
1 - original_dim_cale);
auto real_data = std::make_shared<std::string>(result->num_elements, 0xA5);
result->flat_data = std::shared_ptr<char>(real_data, real_data->data());
auto real_scale = std::make_shared<std::vector<float>>(
original_dims[original_dim_cale], 1.0f);
result->scale_data = std::shared_ptr<float>(real_scale, real_scale->data());
builder_->NewWeight(result);
return result;
}
absl::StatusOr<std::shared_ptr<Tensor>>
BenchmarkUlmWeightsLoader::LoadFromAbsPathPrefix(
absl::string_view prefix, const Tensor::DimsType& dims,
size_t dim_scale_if_any) const {
// If loader calls this function directly, it's always non-quantized weights.
auto result = std::make_shared<Tensor>(dims);
MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false));
builder_->NewWeight(result);
return result;
}
} // namespace xnn_utils
} // namespace mediapipe

View File

@ -0,0 +1,192 @@
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
#include "third_party/XNNPACK/include/xnnpack.h"
namespace mediapipe {
namespace xnn_utils {
struct UlmParams {
size_t num_transformer_M = 18;
size_t batch_size_B = 1;
size_t seq_size_T = 16;
size_t model_dim_D = 1536;
size_t hidden_dim_HD = 8 * 1536;
size_t head_dim_H = 128;
size_t n_heads_N = 12;
size_t voc_size_V = 32000;
bool use_padding = true;
bool final_norm = true;
bool final_project = true;
bool enable_kv_cache = false;
// Path to store reshaped weights as cache. Set empty to disable caching.
std::string weight_cache_path;
};
struct SelfAttentionWeights {
std::shared_ptr<Tensor> pre_norm;
std::shared_ptr<Tensor> k_weight;
std::shared_ptr<Tensor> q_weight;
std::shared_ptr<Tensor> v_weight;
std::shared_ptr<Tensor> per_dim_scale;
std::shared_ptr<Tensor> post_proj_weight;
std::shared_ptr<Tensor> post_norm;
};
struct FeedForwardWeights {
std::shared_ptr<Tensor> pre_norm;
std::shared_ptr<Tensor> layer_1_weight;
std::shared_ptr<Tensor> layer_1_bias;
std::shared_ptr<Tensor> layer_1_gate_weight;
std::shared_ptr<Tensor> layer_1_gate_bias;
std::shared_ptr<Tensor> layer_2_weight;
std::shared_ptr<Tensor> layer_2_bias;
std::shared_ptr<Tensor> post_norm;
std::shared_ptr<Tensor> opt_padding;
};
struct UlmWeights {
std::vector<FeedForwardWeights> ffs;
std::vector<SelfAttentionWeights> sas;
std::shared_ptr<Tensor> final_ln_scale;
std::shared_ptr<Tensor> softmax_linear;
std::shared_ptr<Tensor> softmax_bias;
// Optional. Usually softmax_linear can be used as embedding, but sometimes we
// need to scale/transpose it.
std::shared_ptr<Tensor> token_embedding;
static constexpr absl::string_view kKeyLoadedFromCache{"loaded_from_cache"};
};
class UlmWeightsLoader {
public:
constexpr static absl::string_view kTransformerWeightPrefix{
"params.lm.transformer.x_layers_"};
constexpr static absl::string_view kFinalScaleFilename{
"params.lm.final_ln.scale"};
constexpr static absl::string_view kLogitsFfnBiasFilename{
"params.lm.softmax.logits_ffn.bias.b"};
constexpr static absl::string_view kLogitsFfnWeightFilename{
"params.lm.softmax.logits_ffn.linear.w"};
UlmWeightsLoader(absl::string_view weight_path, const UlmParams& params)
: weight_path_(weight_path), params_(params) {}
virtual ~UlmWeightsLoader() = default;
void SetBuilder(XnnGraphBuilder& builder) { builder_ = &builder; }
virtual absl::StatusOr<UlmWeights> LoadWeights();
virtual absl::StatusOr<SelfAttentionWeights> LoadSelfAttention(int layer_id);
virtual absl::StatusOr<FeedForwardWeights> LoadFeedForward(int layer_id);
UlmParams& ulm_params() { return params_; }
const UlmParams& ulm_params() const { return params_; }
XnnGraphBuilder& builder() const { return *builder_; }
protected:
// Find the files that matches prefix, then read from file.
virtual absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefix(
absl::string_view prefix, const Tensor::DimsType& dims,
size_t dim_scale_if_any) const;
absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefix(
absl::string_view prefix, const Tensor::DimsType& dims) const {
return LoadFromAbsPathPrefix(prefix, dims, 0);
}
absl::StatusOr<std::shared_ptr<Tensor>> TryCacheThenLoadSelfAttention(
absl::string_view filename_prefix) const;
absl::StatusOr<std::shared_ptr<Tensor>> TryCacheThenLoadFeedForward(
absl::string_view filename_prefix,
std::optional<Tensor::DimsType> dims = std::nullopt) const;
virtual absl::StatusOr<std::shared_ptr<Tensor>>
TryCacheThenLoadWeightTranspose(absl::string_view filename_prefix,
Tensor::DimsType original_dims,
size_t original_dim_cale) const;
std::string weight_path_;
UlmParams params_;
XnnGraphBuilder* builder_ = nullptr;
};
// Try: 1. load token embedding from cache; 2. fill token embedding by transpose
// softmax linear then scale; 3. dump token embedding to cache.
struct PrepareTokenEmbeddingDecorator {
static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&);
};
struct TransposeSoftmaxWeightDecorator {
static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&);
};
struct TransposeSelfAttentionWeightDecorator {
// If KQV weight are reshaped, ignore.
// If KQV weight are not properly shaped, load from cache if any, or build.
// If KQV weight are missing, try loading from cache path, or fail if missing.
static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&);
};
// Apply some decoration (in order) to the weights loaded by base class.
template <class... Decorators>
class UlmWeightsLoaderWith : public UlmWeightsLoader {
public:
UlmWeightsLoaderWith(absl::string_view weight_path, const UlmParams& params)
: UlmWeightsLoader(weight_path, params),
decorators_{Decorators::Decorate...} {}
absl::StatusOr<UlmWeights> LoadWeights() override {
ASSIGN_OR_RETURN(auto result, UlmWeightsLoader::LoadWeights());
for (const auto& decorator : decorators_) {
MP_RETURN_IF_ERROR(decorator(*this, result));
}
return result;
}
protected:
std::vector<std::function<absl::Status(const UlmWeightsLoader&, UlmWeights&)>>
decorators_;
};
using DefaultUlmWeightsLoader =
UlmWeightsLoaderWith<TransposeSelfAttentionWeightDecorator,
PrepareTokenEmbeddingDecorator>;
// Generate weights with some random value.
class BenchmarkUlmWeightsLoader : public DefaultUlmWeightsLoader {
public:
explicit BenchmarkUlmWeightsLoader(
const UlmParams& params, xnn_datatype data_type = xnn_datatype_fp32);
absl::StatusOr<std::shared_ptr<Tensor>> TryCacheThenLoadWeightTranspose(
absl::string_view filename_prefix, Tensor::DimsType original_dims,
size_t original_dim_cale) const override;
absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefix(
absl::string_view prefix, const Tensor::DimsType& dims,
size_t dim_scale_if_any) const override;
private:
xnn_datatype data_type_;
std::shared_ptr<Tensor> random_value_buffer_;
};
} // namespace xnn_utils
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_

View File

@ -0,0 +1,21 @@
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
namespace mediapipe {
namespace xnn_utils {
std::vector<float> FillXnnRoPEWeights(size_t max_seq_len, size_t num_channels) {
std::vector<float> out_array(max_seq_len * num_channels);
for (size_t ch_id = 0; ch_id < num_channels / 2; ++ch_id) {
auto timescale = std::pow(1e-4, 2.0 * ch_id / num_channels);
for (size_t seq_id = 0; seq_id < max_seq_len; ++seq_id) {
auto sinusoid_inp = seq_id * timescale;
out_array[seq_id * num_channels + ch_id] = cos(sinusoid_inp);
out_array[seq_id * num_channels + ch_id + num_channels / 2] =
sin(sinusoid_inp);
}
}
return out_array;
}
} // namespace xnn_utils
} // namespace mediapipe

View File

@ -0,0 +1,61 @@
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_
#include <fcntl.h>
#include <sys/mman.h>
#include "absl/cleanup/cleanup.h"
#include "absl/status/statusor.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
namespace xnn_utils {
std::vector<float> FillXnnRoPEWeights(size_t max_seq_len, size_t num_channels);
// expect_size_bytes == 0 means don't check size.
template <typename element_type = char>
static absl::StatusOr<std::shared_ptr<element_type>> LoadBufferFromFile(
absl::string_view file_path, bool use_mmap = true,
size_t expect_size_bytes = 0) {
if (use_mmap) {
int fd = open(file_path.data(), O_RDONLY);
RET_CHECK_GE(fd, 0) << "open " << file_path << " failed";
auto cleanup = absl::MakeCleanup([fd] { close(fd); });
const size_t size = lseek(fd, 0, SEEK_END);
if (expect_size_bytes) {
RET_CHECK_EQ(expect_size_bytes, size)
<< "File size " << size << ", expected " << expect_size_bytes
<< ", file path " << file_path;
}
void* data = mmap(/*addr=*/nullptr, size, /*prot=*/PROT_READ,
/*flags=*/MAP_SHARED, fd, /*offset=*/0);
RET_CHECK_NE(data, MAP_FAILED);
RET_CHECK_NE(data, nullptr);
return std::shared_ptr<element_type>(static_cast<element_type*>(data),
[](auto* p) {});
} else {
auto read_buffer = std::make_shared<std::string>();
MP_RETURN_IF_ERROR(
file::GetContents(file_path, read_buffer.get(), file::Defaults()));
if (expect_size_bytes) {
RET_CHECK_EQ(expect_size_bytes, read_buffer->size())
<< "File size " << read_buffer->size() << ", expected "
<< expect_size_bytes << ", file path " << file_path;
}
return std::shared_ptr<element_type>(
read_buffer, reinterpret_cast<element_type*>(read_buffer->data()));
}
}
} // namespace xnn_utils
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_

View File

@ -0,0 +1,358 @@
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <cstddef>
#include <cstring>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
#include "third_party/XNNPACK/include/xnnpack.h"
namespace mediapipe {
namespace xnn_utils {
absl::Status FillXnnRoPEWeights(Tensor& out_seg_pos) {
RET_CHECK_EQ(out_seg_pos.dims.size(), 2);
const size_t max_seq_len = out_seg_pos.dims[0];
const size_t num_channels = out_seg_pos.dims[1];
return out_seg_pos.LoadFromVec(FillXnnRoPEWeights(max_seq_len, num_channels));
}
std::ostream& operator<<(std::ostream& os, const Tensor& tensor) {
os << "Tensor{dims=[" << tensor.dims << "], datatype=" << tensor.datatype
<< ", num_elements=" << tensor.num_elements << "}";
return os;
}
std::ostream& operator<<(std::ostream& os, const QCTensor& tensor) {
os << "QCTensor{dims=[" << tensor.dims << "], dim_scale=" << tensor.dim_scale
<< " datatype=" << tensor.datatype
<< ", num_elements=" << tensor.num_elements << "}";
return os;
}
bool Tensor::operator==(const Tensor& other) const {
if (dims.size() != other.dims.size()) {
return false;
} else if (datatype != other.datatype) {
return false;
} else {
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] != other.dims[i]) {
return false;
}
}
}
return 0 == memcmp(Data(), other.Data(), num_elements * ElementSize());
}
void Tensor::AllocateBufferIfNeeded() {
if (!flat_data) {
auto real_buffer = std::make_shared<std::string>();
real_buffer->reserve(num_elements * ElementSize() + XNN_EXTRA_BYTES);
flat_data = std::shared_ptr<char>(real_buffer, real_buffer->data());
}
}
void* Tensor::Data() {
DCHECK(flat_data)
<< "If this is weight, you may need to call one of the LoadFrom*()";
return flat_data.get();
}
std::shared_ptr<Tensor> Tensor::Slice(DimsType offset) {
DCHECK(flat_data);
CHECK_EQ(offset.size(), dims.size()) << offset << " vs. " << dims;
// offset: [0, k, 0, 0], dims: [1, K, _, _]. dims before k must be 1.
bool found_non_zero_offset = false;
int index_k = -1;
for (int i = 0; i < dims.size(); ++i) {
if (found_non_zero_offset) {
DCHECK_EQ(offset[i], 0);
} else if (offset[i] != 0) {
found_non_zero_offset = true;
index_k = i;
}
}
DCHECK(found_non_zero_offset) << offset;
return Slice(index_k, offset[index_k]);
}
std::shared_ptr<Tensor> Tensor::Slice(size_t index, size_t offset) {
size_t num_elements_offset = 1;
DimsType new_dim = dims;
for (int i = 0; i < dims.size(); ++i) {
if (i < index) {
DCHECK_EQ(dims[i], 1);
} else if (i == index) {
num_elements_offset *= offset;
new_dim[i] = 1;
} else {
num_elements_offset *= dims[i];
}
}
auto result = std::make_shared<Tensor>(std::move(new_dim), datatype);
result->flat_data = std::shared_ptr<char>(
flat_data, flat_data.get() + num_elements_offset * ElementSize());
return result;
}
Tensor& Tensor::Borrow(std::shared_ptr<Tensor> other, size_t element_offset) {
DCHECK_EQ(datatype, other->datatype);
DCHECK_EQ(dims.size(), other->dims.size());
flat_data = std::shared_ptr<char>(
other->flat_data,
other->flat_data.get() + element_offset * ElementSize());
return *this;
}
std::shared_ptr<Tensor> Tensor::View() { return View(dims); }
std::shared_ptr<Tensor> Tensor::View(DimsType as_dims, size_t) {
auto result = std::make_shared<Tensor>(as_dims, datatype);
DCHECK_LE(result->num_elements, num_elements);
result->flat_data = flat_data;
return result;
}
const void* Tensor::Data() const { return const_cast<Tensor*>(this)->Data(); }
absl::Status Tensor::DefineAsExternal(xnn_subgraph& subgraph, uint32_t flags) {
uint32_t id;
RET_CHECK_EQ(xnn_status_success,
xnn_define_tensor_value(&subgraph, datatype, dims.size(),
dims.data(), /*data=*/nullptr,
/*external_id=*/tensor_id, flags, &id));
if (tensor_id == XNN_INVALID_VALUE_ID) {
RET_CHECK_NE(id, XNN_INVALID_VALUE_ID);
tensor_id = id;
} else {
RET_CHECK_EQ(id, tensor_id);
}
return absl::OkStatus();
}
absl::Status Tensor::DefineAsInput(xnn_subgraph& subgraph) {
return DefineAsExternal(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT);
}
absl::Status Tensor::DefineAsOutput(xnn_subgraph& subgraph) {
return DefineAsExternal(subgraph, XNN_VALUE_FLAG_EXTERNAL_OUTPUT);
}
absl::Status Tensor::DefineAsIntermediateTensor(xnn_subgraph& subgraph) {
RET_CHECK_EQ(tensor_id, XNN_INVALID_VALUE_ID);
return DefineAsExternal(subgraph, 0);
}
absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) {
RET_CHECK_EQ(
xnn_status_success,
xnn_define_tensor_value(&subgraph, datatype, dims.size(), dims.data(),
Data(), tensor_id, flags, &tensor_id));
RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID);
return absl::OkStatus();
}
absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph) {
RET_CHECK_EQ(tensor_id, XNN_INVALID_VALUE_ID);
return DefineWeight(subgraph, 0);
}
absl::Status Tensor::DefineRope(xnn_subgraph& subgraph) {
RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID);
return DefineWeight(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT);
}
absl::Status Tensor::LoadFromBuffer(const void* buffer) {
AllocateBufferIfNeeded();
memcpy(Data(), buffer, num_elements * ElementSize());
return absl::OkStatus();
}
absl::Status Tensor::LoadFromVec(const std::vector<float>& data,
bool exact_match) {
AllocateBufferIfNeeded();
if (exact_match) {
RET_CHECK_EQ(num_elements * ElementSize(), data.size() * sizeof(float));
}
memcpy(Data(), data.data(), data.size() * sizeof(float));
return absl::OkStatus();
}
absl::Status Tensor::LoadFromVec(std::vector<float>&& data, bool exact_match) {
if (exact_match) {
RET_CHECK_EQ(num_elements * ElementSize(), data.size() * sizeof(float));
}
auto real_buffer = std::make_shared<std::vector<float>>(std::move(data));
if (real_buffer->size() < num_elements) {
real_buffer->resize(num_elements);
}
flat_data = std::shared_ptr<char>(
real_buffer, reinterpret_cast<char*>(real_buffer->data()));
return absl::OkStatus();
}
absl::Status Tensor::DumpToBuffer(void* buffer) {
memcpy(buffer, Data(), num_elements * ElementSize());
return absl::OkStatus();
}
absl::Status Tensor::DumpToVec(std::vector<float>& out_data, bool exact_match) {
if (exact_match) {
RET_CHECK_EQ(num_elements * ElementSize(), out_data.size() * sizeof(float));
} else {
out_data.resize(num_elements);
}
memcpy(out_data.data(), Data(), num_elements * ElementSize());
return absl::OkStatus();
}
absl::Status Tensor::DumpToFile(absl::string_view file_path) {
return file::SetContents(
file_path,
absl::string_view(flat_data.get(), num_elements * ElementSize()),
file::Defaults());
}
absl::Status Tensor::LoadFromFile(absl::string_view file_path, bool use_mmap,
bool exact_match) {
const size_t expected_size_in_bytes =
exact_match ? num_elements * ElementSize() : 0;
ASSIGN_OR_RETURN(flat_data, LoadBufferFromFile(file_path, use_mmap,
expected_size_in_bytes));
return absl::OkStatus();
}
std::shared_ptr<Tensor> Tensor::Transpose() {
DCHECK_EQ(dims.size(), 2);
DimsType out_dims{dims.rbegin(), dims.rend()};
auto result = std::make_shared<Tensor>(std::move(out_dims), datatype);
result->AllocateBufferIfNeeded();
xnn_status s;
const DimsType perm{1, 0};
if (datatype == xnn_datatype_fp32) {
s = xnn_run_transpose_nd_x32(Data(), result->Data(), dims.size(),
dims.data(), perm.data(),
/*flags=*/0, /*threadpool=*/nullptr);
} else {
LOG(FATAL) << "Need update to support new type";
}
DCHECK_EQ(s, xnn_status_success);
return (s == xnn_status_success) ? result : nullptr;
}
absl::StatusOr<std::shared_ptr<Tensor>> Tensor::ConvertToF32() {
auto result = std::make_shared<Tensor>(dims, xnn_datatype_fp32);
MP_RETURN_IF_ERROR(result->LoadFromBuffer(Data()));
return result;
}
absl::Status QCTensor::LoadFromFile(absl::string_view quantized_weight_filename,
absl::string_view scale_filename,
bool use_mmap, bool exact_match) {
size_t scale_element_size = dims[dim_scale];
ASSIGN_OR_RETURN(flat_data,
LoadBufferFromFile(quantized_weight_filename, use_mmap,
exact_match ? num_elements : 0));
ASSIGN_OR_RETURN(scale_data,
LoadBufferFromFile<float>(
scale_filename, use_mmap,
exact_match ? scale_element_size * sizeof(float) : 0));
return absl::OkStatus();
}
absl::Status QCTensor::DumpToFile(absl::string_view file_path) {
MP_RETURN_IF_ERROR(file::SetContents(
file_path,
absl::string_view(flat_data.get(), num_elements * ElementSize()),
file::Defaults()));
return file::SetContents(
absl::StrCat(file_path, kQuantizedScaleSuffix),
absl::string_view(reinterpret_cast<char*>(scale_data.get()),
dims[dim_scale] * sizeof(float)),
file::Defaults());
}
absl::Status QCTensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) {
RET_CHECK_EQ(
xnn_status_success,
xnn_define_channelwise_quantized_tensor_value(
&subgraph, datatype, scale_data.get(), dims.size(), dim_scale,
dims.data(), Data(), XNN_INVALID_VALUE_ID, flags, &tensor_id))
<< *this;
RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID);
return absl::OkStatus();
}
void QCTensor::AllocateBufferIfNeeded() {
Tensor::AllocateBufferIfNeeded();
if (!scale_data) {
auto real_buffer = std::make_shared<std::vector<float>>();
real_buffer->reserve(dims[dim_scale]);
scale_data = std::shared_ptr<float>(real_buffer, real_buffer->data());
}
}
std::shared_ptr<Tensor> QCTensor::Transpose() {
DCHECK_EQ(dims.size(), 2);
size_t channel_size = dims[dim_scale];
DimsType out_dims{dims.rbegin(), dims.rend()};
auto result = std::make_shared<QCTensor>(std::move(out_dims), 1 - dim_scale);
result->AllocateBufferIfNeeded();
memcpy(result->scale_data.get(), scale_data.get(),
channel_size * sizeof(float));
xnn_status s;
const DimsType perm{1, 0};
if (datatype == xnn_datatype_qcint8) {
s = xnn_run_transpose_nd_x8(Data(), result->Data(), dims.size(),
dims.data(), perm.data(),
/*flags=*/0, /*threadpool=*/nullptr);
} else {
LOG(FATAL) << "Need update to support new type";
}
DCHECK_EQ(s, xnn_status_success);
return (s == xnn_status_success) ? result : nullptr;
}
absl::StatusOr<std::shared_ptr<Tensor>> QCTensor::ConvertToF32() {
auto result = std::make_shared<Tensor>(dims, xnn_datatype_fp32);
// TODO: proper implement.
LOG(WARNING) << "This is fake impl";
MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false));
return result;
}
std::shared_ptr<Tensor> QCTensor::View(DimsType as_dims,
size_t dim_scale_if_any) {
auto result = std::make_shared<QCTensor>(as_dims, dim_scale_if_any);
DCHECK_LE(result->num_elements, num_elements);
result->flat_data = flat_data;
result->scale_data = scale_data;
return result;
}
} // namespace xnn_utils
} // namespace mediapipe

View File

@ -0,0 +1,202 @@
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_
#include <fcntl.h>
#include <sys/mman.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
#include "third_party/XNNPACK/include/xnnpack.h"
#include "util/gtl/stl_logging.h"
namespace mediapipe {
namespace xnn_utils {
static constexpr absl::string_view kQuantizedScaleSuffix{"_quantized_scale"};
static constexpr absl::string_view kSparsityParamsSuffix{"_sparsity_params"};
struct Tensor {
using DimsType = std::vector<size_t>;
explicit Tensor(DimsType in_dims, xnn_datatype datatype_ = xnn_datatype_fp32)
: dims(std::move(in_dims)),
num_elements(dims.empty() ? 0
: std::accumulate(std::begin(dims),
std::end(dims), size_t(1),
std::multiplies<size_t>())),
datatype(datatype_) {}
Tensor(Tensor&& other) = default;
Tensor& operator=(const Tensor& other) = delete;
Tensor& operator=(Tensor&& other) = default;
virtual ~Tensor() = default;
bool operator==(const Tensor& other) const;
void SetMetadata(absl::string_view key, int value) { metadata[key] = value; }
std::optional<int> GetMetadata(absl::string_view key) const {
if (metadata.contains(key)) {
return metadata.at(key);
}
return std::nullopt;
}
// Read weights from file.
template <xnn_datatype xnn_datatype_ = xnn_datatype_fp32>
static absl::StatusOr<std::shared_ptr<Tensor>> FromFile(
absl::string_view file_path, DimsType dims, bool use_mmap = true) {
auto result = std::make_shared<Tensor>(std::move(dims), xnn_datatype_);
MP_RETURN_IF_ERROR(
result->LoadFromFile(file_path, use_mmap, /*exact_match=*/true));
return result;
}
virtual absl::Status DefineAsExternal(xnn_subgraph& subgraph, uint32_t flags);
absl::Status DefineAsInput(xnn_subgraph& subgraph);
absl::Status DefineAsOutput(xnn_subgraph& subgraph);
absl::Status DefineAsIntermediateTensor(xnn_subgraph& subgraph);
virtual absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags);
absl::Status DefineWeight(xnn_subgraph& subgraph);
absl::Status DefineRope(xnn_subgraph& subgraph);
absl::Status LoadFromBuffer(const void* buffer);
absl::Status LoadFromVec(const std::vector<float>& data,
bool exact_match = true);
absl::Status LoadFromVec(std::vector<float>&& data, bool exact_match = true);
absl::Status LoadFromFile(absl::string_view file_path) {
return LoadFromFile(file_path, true, true);
}
virtual absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap,
bool exact_match);
absl::Status DumpToBuffer(void* buffer);
absl::Status DumpToVec(std::vector<float>& out_data, bool exact_match = true);
virtual absl::Status DumpToFile(absl::string_view file_path);
// If ith offset is 0, view's ith dim equals to original ith dim, otherwise 1.
std::shared_ptr<Tensor> Slice(DimsType offset);
// Slice along the `index`th dimension, offset at this dimension.
std::shared_ptr<Tensor> Slice(size_t index, size_t offset);
// Point the underline data to the borrowed tensor's data.
Tensor& Borrow(std::shared_ptr<Tensor>, size_t element_offset = 0);
std::shared_ptr<Tensor> View();
virtual std::shared_ptr<Tensor> View(DimsType as_dims,
size_t dim_scale_if_any = 0);
Tensor& MarkOutput() {
AllocateBufferIfNeeded();
is_output_tensor = true;
return *this;
}
virtual void* Data();
const void* Data() const;
template <typename T>
T* DataAs() {
DCHECK_EQ(ElementSize(), sizeof(T));
return static_cast<T*>(Data());
}
template <typename T>
const T* DataAs() const {
return static_cast<const T*>(Data());
}
virtual std::shared_ptr<Tensor> Transpose();
virtual absl::StatusOr<std::shared_ptr<Tensor>> ConvertToF32();
DimsType dims;
size_t num_elements = 0;
xnn_datatype datatype = xnn_datatype_invalid;
uint32_t tensor_id = XNN_INVALID_VALUE_ID;
// shared_ptr to make TensorMetadata copyable.
std::shared_ptr<char> flat_data;
protected:
friend class XnnGraphBuilder;
friend class XnnGraph;
// Actually allocate buffer unless necessary.
virtual void AllocateBufferIfNeeded();
virtual size_t ElementSize() const { return 4; }
bool is_output_tensor = false;
absl::flat_hash_map<std::string, int> metadata;
};
std::ostream& operator<<(std::ostream& os, const Tensor& tensor);
// Channelwise Quantized.
struct QCTensor : public Tensor {
explicit QCTensor(DimsType in_dims, size_t dim_scale_if_any)
: Tensor(std::move(in_dims)), dim_scale(dim_scale_if_any) {
datatype = xnn_datatype_qcint8;
CHECK_LT(dim_scale, 4);
}
void AllocateBufferIfNeeded() override;
size_t ElementSize() const override { return 1; }
virtual absl::Status LoadFromFile(absl::string_view quantized_weight_filename,
absl::string_view scale_filename,
bool use_mmap, bool exact_match);
// Append kQuantizedScaleSuffix to use as scale filename.
absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap,
bool exact_match) override {
return LoadFromFile(file_path,
absl::StrCat(file_path, kQuantizedScaleSuffix),
use_mmap, exact_match);
}
absl::Status DumpToFile(absl::string_view file_path) override;
absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags) override;
std::shared_ptr<Tensor> Transpose() override;
absl::StatusOr<std::shared_ptr<Tensor>> ConvertToF32() override;
std::shared_ptr<Tensor> View(DimsType as_dims,
size_t dim_scale_if_any) override;
std::shared_ptr<float> scale_data;
// Index of the dimension to scale.
size_t dim_scale;
};
std::ostream& operator<<(std::ostream& os, const QCTensor& tensor);
absl::Status FillXnnRoPEWeights(Tensor& out_seg_pos);
} // namespace xnn_utils
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_