diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/BUILD b/mediapipe/tasks/cc/text/utils/xnn_utils/BUILD new file mode 100644 index 000000000..4b58cb8f6 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/xnn_utils/BUILD @@ -0,0 +1 @@ +# Utilities needed to interact with XNNPACK. diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.cc b/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.cc new file mode 100644 index 000000000..225b5985d --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.cc @@ -0,0 +1,887 @@ +#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/source_location.h" +#include "file/base/helpers.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" +#include "third_party/XNNPACK/include/xnnpack.h" +#include "util/gtl/stl_logging.h" + +namespace mediapipe { +namespace xnn_utils { +namespace { + +// XNNPACK supports broadcasting, this function infers the output shape +// based on input tensor shapes. 
+// Shapes are aligned on their trailing axis. Where both operands have an axis,
+// an lhs extent of 1 yields the rhs extent; otherwise the lhs extent is used.
+// Axes present in only one operand are copied from that operand.
+// Both shapes must be non-empty and broadcast-compatible; this is verified
+// with DCHECK only, so it is not reported where DCHECK is compiled out.
+std::vector<size_t> OutDimsForElementwiseOp(const Tensor& lhs,
+                                            const Tensor& rhs) {
+  DCHECK(!lhs.dims.empty());
+  DCHECK(!rhs.dims.empty());
+  // Reversed copies put the trailing axis of each shape at index 0, so equal
+  // indices refer to axes that broadcast against each other.
+  std::vector<size_t> lhs_dims_rev(lhs.dims.rbegin(), lhs.dims.rend());
+  std::vector<size_t> rhs_dims_rev(rhs.dims.rbegin(), rhs.dims.rend());
+  DCHECK([&]() -> bool {
+    for (size_t i = 0; i < std::min(lhs_dims_rev.size(), rhs_dims_rev.size());
+         ++i) {
+      if ((lhs_dims_rev[i] != rhs_dims_rev[i]) && (lhs_dims_rev[i] != 1) &&
+          (rhs_dims_rev[i] != 1)) {
+        return false;
+      }
+    }
+    return true;
+  }()) << "lhs "
+       << lhs.dims << " rhs " << rhs.dims;
+  std::vector<size_t> out_dims(
+      std::max(lhs_dims_rev.size(), rhs_dims_rev.size()));
+  // size_t index: the bounds it is compared against are all size_t.
+  for (size_t i = 0; i < out_dims.size(); ++i) {
+    if (lhs_dims_rev.size() <= i) {
+      out_dims[i] = rhs_dims_rev[i];
+    } else if (rhs_dims_rev.size() <= i) {
+      out_dims[i] = lhs_dims_rev[i];
+    } else {
+      out_dims[i] = lhs_dims_rev[i] == 1 ? rhs_dims_rev[i] : lhs_dims_rev[i];
+    }
+  }
+  // Undo the reversal so the leading axis comes first again.
+  return std::vector<size_t>(out_dims.rbegin(), out_dims.rend());
+}
+
+// If out_id is invalid, we need to allocate tensor for intermediate result.
+// Otherwise, set out_id in out_metadata.
+absl::Status MaybeAllocateIntermediateTensor(xnn_subgraph_t subgraph,
+                                             uint32_t out_id,
+                                             Tensor& out_metadata) {
+  RET_CHECK_GT(out_metadata.dims.size(), 0);
+  if (out_id == XNN_INVALID_VALUE_ID) {
+    // The output is intermediate, thus allocate tensor.
+    MP_RETURN_IF_ERROR(out_metadata.DefineAsIntermediateTensor(*subgraph));
+  } else {
+    out_metadata.tensor_id = out_id;
+  }
+
+  return absl::OkStatus();
+}
+
+// Same as above, using the id already stored in out_metadata as out_id.
+absl::Status MaybeAllocateIntermediateTensor(xnn_subgraph_t subgraph,
+                                             Tensor& out_metadata) {
+  return MaybeAllocateIntermediateTensor(subgraph, out_metadata.tensor_id,
+                                         out_metadata);
+}
+
+// Always takes the allocation branch by passing XNN_INVALID_VALUE_ID.
+absl::Status AllocateIntermediateTensor(xnn_subgraph_t subgraph,
+                                        Tensor& out_metadata) {
+  return MaybeAllocateIntermediateTensor(subgraph, XNN_INVALID_VALUE_ID,
+                                         out_metadata);
+}
+
+// Fills scale[0..cnt) from weight[0..cnt):
+// 1.0/jax.nn.softplus(0.0) = 1.442695041
+// scale = softplus(w) * 1.442695041 / np.sqrt(query.shape[-1])
+// `weight` and `scale` must each hold at least `cnt` floats and `query_dims`
+// must be non-empty (its last entry is the divisor's dimension).
+void SoftPlus(size_t cnt, const std::vector<size_t>& query_dims, float* weight,
+              float* scale) {
+  DCHECK(!query_dims.empty());
+  constexpr double r_softplus_0 = 1.442695041;
+  // softplus(x) = np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)
+  // scale = softplus(per_dim_scale) / (sqrt(input.dims[-1]) * softplus(0))
+  const double r_softplus_0_over_sqrt_d =
+      r_softplus_0 / std::sqrt(query_dims.back());
+  for (size_t i = 0; i < cnt; ++i) {
+    // Use the std:: overloads on a double: an unqualified abs() may resolve to
+    // the C `int abs(int)` and silently truncate the weight.
+    const double w = weight[i];
+    const double softplus =
+        std::log1p(std::exp(-std::abs(w))) + std::max(w, 0.0);
+    // Store, then scale, in two steps to keep the original rounding order.
+    scale[i] = static_cast<float>(softplus);
+    scale[i] *= r_softplus_0_over_sqrt_d;
+  }
+}
+
+}  // namespace
+
+absl::StatusOr<std::unique_ptr<XnnGraph>> XnnGraphBuilder::Build(
+    std::unique_ptr<RuntimeConfigs> runtime_configs) {
+  // Without caller-supplied configs: single thread, profiling off.
+  if (!runtime_configs) {
+    runtime_configs = std::make_unique<RuntimeConfigs>();
+    runtime_configs->xnn_num_threads = 1;
+    runtime_configs->xnn_profile = false;
+  }
+  VLOG(2) << "XnnGraphBuilder::Build() building...";
+  auto build_begin = absl::Now();
+  RET_CHECK_EQ(xnn_status_success, xnn_initialize(nullptr));
+
+  absl::flat_hash_set<std::shared_ptr<Tensor>> output_tensors;
+  {
+    // Ids continue after the inputs: first output tensors, then rope weights.
+    uint32_t cnt = input_tensors_.size();
+    for (auto& t : interm_tensors_) {
+      if (t->is_output_tensor) {
+        RET_CHECK_EQ(t->tensor_id, XNN_INVALID_VALUE_ID);
+        t->tensor_id = cnt++;
+        output_tensors.insert(t);
+      }
+    }
+    // Erase in a second pass so interm_tensors_ is not modified while it is
+    // being iterated above.
+    for (auto& t : output_tensors) {
+      interm_tensors_.erase(t);
+    }
+    for (auto& t : rope_weigths_) {
+      interm_tensors_.erase(t);
+      t->tensor_id = cnt++;
+    }
+  }
+
xnn_subgraph_t subgraph_ptr = nullptr; + RET_CHECK_EQ(xnn_status_success, + xnn_create_subgraph( + /*external_value_ids=*/input_tensors_.size() + + output_tensors.size() + rope_weigths_.size(), + /*flags=*/0, &subgraph_ptr)); + RET_CHECK_NE(subgraph_ptr, nullptr); + + XnnSubgraphPtr subgraph{subgraph_ptr, xnn_delete_subgraph}; + + for (auto& input : input_tensors_) { + MP_RETURN_IF_ERROR(input->DefineAsInput(*subgraph)); + } + for (auto& output : output_tensors) { + MP_RETURN_IF_ERROR(output->DefineAsOutput(*subgraph)); + } + { + for (auto& t : rope_weigths_) { + MP_RETURN_IF_ERROR(t->DefineRope(*subgraph)); + } + } + + for (auto& [loc, step] : build_steps_) { + if (auto s = step(subgraph.get()); !s.ok()) { + s.AddSourceLocation(loc); + return s; + } + } + + XnnGraph result(std::move(subgraph), std::move(runtime_configs)); + result.input_tensors_ = std::move(input_tensors_); + result.output_tensors_ = std::move(output_tensors); + result.interm_tensors_ = std::move(interm_tensors_); + + VLOG(2) << "XnnGraphBuilder::Build() creating runtime..."; + auto create_begin = absl::Now(); + MP_RETURN_IF_ERROR(result.CreateRuntime()); + VLOG(2) << "XnnGraphBuilder::Build() setting up runtime..."; + auto setup_begin = absl::Now(); + MP_RETURN_IF_ERROR(result.SetupRuntime()); + + auto end = absl::Now(); + VLOG(2) << "XnnGraphBuilder::Build() done build, Total " << end - build_begin + << ", create runtime " << setup_begin - create_begin + << ", setup runtime " << end - setup_begin; + return std::make_unique(std::move(result)); +} + +absl::StatusOr> XnnGraphBuilder::NewInput( + Tensor::DimsType dims, absl::SourceLocation loc) { + auto t = std::make_shared(std::move(dims), data_type_); + t->AllocateBufferIfNeeded(); + t->tensor_id = input_tensors_.size(); + input_tensors_.insert(t); + return t; +} + +absl::StatusOr> XnnGraphBuilder::NewWeight( + absl::string_view file_path, Tensor::DimsType dims, + absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto t, NewWeight(std::move(dims))); 
+ MP_RETURN_IF_ERROR(t->LoadFromFile(file_path)); + return t; +} + +absl::StatusOr> XnnGraphBuilder::NewWeight( + Tensor::DimsType dims, absl::SourceLocation loc) { + auto t = std::make_shared(std::move(dims), data_type_); + NewWeight(t, loc); + return t; +} + +void XnnGraphBuilder::NewWeight(std::shared_ptr t, + absl::SourceLocation loc) { + build_steps_.push_back( + {loc, [this, t](xnn_subgraph_t subgraph) -> absl::Status { + if (interm_tensors_.contains(t)) { + MP_RETURN_IF_ERROR(t->DefineWeight(*subgraph)); + } + return absl::OkStatus(); + }}); + + interm_tensors_.insert(t); +} + +absl::StatusOr> XnnGraphBuilder::IntermediateTensor( + Tensor::DimsType dims, absl::SourceLocation loc) { + auto t = std::make_shared(std::move(dims), data_type_); + + build_steps_.push_back( + {loc, [this, t](xnn_subgraph_t subgraph) -> absl::Status { + // Could be moved to output tensors, thus need check. + if (interm_tensors_.contains(t)) { + return AllocateIntermediateTensor(subgraph, *t); + } + return absl::OkStatus(); + }}); + + interm_tensors_.insert(t); + return t; +} + +absl::StatusOr> XnnGraphBuilder::Reshape( + std::shared_ptr input, Tensor::DimsType new_dims, + absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(new_dims))); + RET_CHECK_EQ(input->num_elements, output->num_elements) + << "otherwise reshape does not make sense."; + + build_steps_.push_back( + {loc, [this, input, output](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor( + subgraph, output->tensor_id, *output)); + + RET_CHECK_EQ(xnn_status_success, + xnn_define_static_reshape( + subgraph, output->dims.size(), output->dims.data(), + input->tensor_id, output->tensor_id, /*flags=*/0)); + return absl::OkStatus(); + }}); + return output; +} + +absl::StatusOr> XnnGraphBuilder::FullConn( + std::shared_ptr input, std::shared_ptr weight, + std::shared_ptr bias, FullConnParams params, + absl::SourceLocation loc) { + const auto& input_dim 
= input->dims; + const auto& weight_dim = weight->dims; + DCHECK_GT(input_dim.size(), 1); + DCHECK_GE(weight_dim.size(), 2); + if (weight_dim.size() == 3) { + RET_CHECK_EQ(weight_dim[0], 1); + } else if (weight_dim.size() == 4) { + RET_CHECK_EQ(weight_dim[0], 1); + RET_CHECK_EQ(weight_dim[1], 1); + } + if (bias) { + RET_CHECK_LE(bias->dims.size(), 1); + } + + Tensor::DimsType out_dims = input_dim; + // Not considering reshape 2D + if (params.transpose) { + RET_CHECK_EQ(weight_dim.size(), 2) << "otherwise change following line"; + RET_CHECK_EQ(input_dim.back(), *(weight_dim.end() - 2)); + out_dims.back() = weight_dim.back(); + } else { + RET_CHECK_EQ(input_dim.back(), weight_dim.back()); + out_dims.pop_back(); + for (size_t i = 0; i < weight_dim.size() - 1; ++i) { + // NHD . BTD -> NHBT + out_dims.push_back(weight_dim[i]); + } + } + ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(out_dims))); + + build_steps_.push_back( + {loc, + [this, input, weight, bias, params, + output](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor( + subgraph, output->tensor_id, *output)); + + RET_CHECK_EQ( + xnn_status_success, + xnn_define_fully_connected( + subgraph, params.out_min, params.out_max, input->tensor_id, + weight->tensor_id, + bias ? bias->tensor_id : XNN_INVALID_VALUE_ID, + output->tensor_id, + /*flags=*/params.transpose ? 
XNN_FLAG_TRANSPOSE_WEIGHTS : 0)); + + return absl::OkStatus(); + }}); + return output; +} + +absl::StatusOr> XnnGraphBuilder::Permute( + std::shared_ptr input, Tensor::DimsType permute, + absl::SourceLocation loc) { + RET_CHECK_EQ(input->dims.size(), permute.size()); + const auto& old_dims = input->dims; + std::vector new_dims; + for (size_t i = 0; i < permute.size(); ++i) { + new_dims.push_back(old_dims[permute[i]]); + } + ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(new_dims))); + + build_steps_.push_back( + {loc, + [this, permute, input, output](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); + + RET_CHECK_EQ(xnn_status_success, + xnn_define_static_transpose( + subgraph, permute.size(), permute.data(), + input->tensor_id, output->tensor_id, /*flags=*/0)); + return absl::OkStatus(); + }}); + return output; +} + +absl::StatusOr> XnnGraphBuilder::Square( + std::shared_ptr input, absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims)); + + build_steps_.push_back( + {loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor( + subgraph, output->tensor_id, *output)); + RET_CHECK_EQ( + xnn_status_success, + xnn_define_square(subgraph, input->tensor_id, output->tensor_id, + /*flags=*/0)); + return absl::Status(); + }}); + + return output; +} + +absl::StatusOr> XnnGraphBuilder::Softmax( + std::shared_ptr input, absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims)); + + build_steps_.push_back( + {loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor( + subgraph, output->tensor_id, *output)); + RET_CHECK_EQ( + xnn_status_success, + xnn_define_softmax(subgraph, input->tensor_id, output->tensor_id, + /*flags=*/0)); + return absl::Status(); + }}); + + return output; +} + +absl::StatusOr> 
XnnGraphBuilder::SquareRoot(
+    std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
+  // Output has the same dims as the input.
+  ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));
+
+  build_steps_.push_back(
+      {loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status {
+         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
+             subgraph, output->tensor_id, *output));
+         RET_CHECK_EQ(xnn_status_success,
+                      xnn_define_square_root(subgraph, input->tensor_id,
+                                             output->tensor_id,
+                                             /*flags=*/0));
+         return absl::OkStatus();
+       }});
+
+  return output;
+}
+
+// Queues a mean over the last axis of `input`, then reshapes the result so the
+// returned tensor has input's dims with the last extent replaced by 1.
+// `input` must have at least two axes: the intermediate mean result drops the
+// last axis and tensors with no dims are rejected when the graph is built.
+absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::AvgLastDim(
+    std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
+  // Fail here rather than forming `end() - 1` on an empty shape or deferring
+  // the failure to Build().
+  RET_CHECK_GT(input->dims.size(), 1);
+  ASSIGN_OR_RETURN(auto before_reshape,
+                   IntermediateTensor(Tensor::DimsType{input->dims.begin(),
+                                                       input->dims.end() - 1}));
+  build_steps_.push_back(
+      {loc,
+       [this, input, before_reshape](xnn_subgraph_t subgraph) -> absl::Status {
+         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
+             subgraph, before_reshape->tensor_id, *before_reshape));
+         size_t reduction_axis = input->dims.size() - 1;
+         RET_CHECK_EQ(
+             xnn_status_success,
+             xnn_define_static_mean(subgraph, 1, &reduction_axis,
+                                    input->tensor_id, before_reshape->tensor_id,
+                                    /*flags=*/0));
+         return absl::OkStatus();
+       }});
+
+  Tensor::DimsType new_dims = input->dims;
+  new_dims.back() = 1;
+  // Forward loc so a failing reshape step is attributed to our caller.
+  return Reshape(before_reshape, std::move(new_dims), loc);
+}
+
+// Root mean square over the last axis: SquareRoot(AvgLastDim(Square(input))).
+absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Rms(
+    std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
+  ASSIGN_OR_RETURN(auto sqr_out, Square(input, loc));
+
+  ASSIGN_OR_RETURN(auto mean_out, AvgLastDim(sqr_out, loc));
+
+  return SquareRoot(mean_out, loc);
+}
+
+absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::RmsNorm(
+    std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> scale,
+    absl::SourceLocation loc) {
+  ASSIGN_OR_RETURN(auto rms_out, Rms(input));
+
+  // Lower-bound the rms at 1e-6 before it is used as a divisor.
+  ASSIGN_OR_RETURN(auto clamped_rms, Clamp(rms_out, {.out_min = 1e-6}));
+
+  // div_out = input / rms
+  ASSIGN_OR_RETURN(auto div_out, ElementDiv(input, clamped_rms));
+
+  // div_out * (1 + scale) = div_out + div_out * scale
+
ASSIGN_OR_RETURN(auto normed_div_out, ElementMul(div_out, scale)); + + return ElementAdd(div_out, normed_div_out); +} + +absl::StatusOr> XnnGraphBuilder::ElementAdd( + std::shared_ptr lhs, float rhs, ClampParams params, + absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1})); + MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector({rhs}))); + + return ElementAdd(lhs, rhs_tensor, params, loc); +} + +absl::StatusOr> XnnGraphBuilder::ElementAdd( + std::shared_ptr lhs, std::shared_ptr rhs, + ClampParams params, absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto output, + IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs))); + + build_steps_.push_back( + {loc, + [this, lhs, rhs, output, + params](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); + RET_CHECK_EQ(xnn_status_success, + xnn_define_add2(subgraph, params.out_min, params.out_max, + lhs->tensor_id, rhs->tensor_id, + output->tensor_id, /*flags=*/0)); + return absl::OkStatus(); + }}); + + return output; +} + +absl::StatusOr> XnnGraphBuilder::ElementMul( + std::shared_ptr lhs, float rhs, ClampParams params, + absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1})); + MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector({rhs}))); + + return ElementMul(lhs, rhs_tensor, params, loc); +} + +absl::StatusOr> XnnGraphBuilder::ElementMul( + std::shared_ptr lhs, std::shared_ptr rhs, + ClampParams params, absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto output, + IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs))); + + build_steps_.push_back( + {loc, + [this, lhs, rhs, output, + params](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); + RET_CHECK_EQ( + xnn_status_success, + xnn_define_multiply2(subgraph, params.out_min, params.out_max, + lhs->tensor_id, rhs->tensor_id, + output->tensor_id, /*flags=*/0)); + return 
absl::OkStatus(); + }}); + + return output; +} + +absl::StatusOr> XnnGraphBuilder::ElementDiv( + std::shared_ptr lhs, float rhs, ClampParams params, + absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1})); + MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector({rhs}))); + + return ElementDiv(lhs, rhs_tensor, params, loc); +} + +absl::StatusOr> XnnGraphBuilder::ElementDiv( + std::shared_ptr lhs, std::shared_ptr rhs, + ClampParams params, absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto output, + IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs))); + + build_steps_.push_back( + {loc, + [this, lhs, rhs, output, + params](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); + RET_CHECK_EQ( + xnn_status_success, + xnn_define_divide(subgraph, params.out_min, params.out_max, + lhs->tensor_id, rhs->tensor_id, + output->tensor_id, /*flags=*/0)); + return absl::OkStatus(); + }}); + + return output; +} + +// TODO: write an op? 
+absl::StatusOr> XnnGraphBuilder::PerDimScale( + std::shared_ptr input, std::shared_ptr per_dim_scale, + absl::SourceLocation loc) { + // input: B T N H + // 1/softplus(0) = 1.442695041 + // scale = softplus(w) * 1.442695041 / np.sqrt(query.shape[-1]) + // query = query * scale + const auto& input_dim = input->dims; + DCHECK_GE(input_dim.size(), 1); + const size_t H = input_dim.back(); + + if (!per_dim_scale_cache_.contains(H) || + !per_dim_scale_cache_[H].contains(per_dim_scale.get())) { + ASSIGN_OR_RETURN(auto cached_pds, NewWeight(per_dim_scale->dims)); + + auto* pds_in = static_cast(per_dim_scale->Data()); + std::vector pds_scaled(per_dim_scale->num_elements); + SoftPlus(per_dim_scale->num_elements, input_dim, pds_in, pds_scaled.data()); + MP_RETURN_IF_ERROR(cached_pds->LoadFromVec(std::move(pds_scaled))); + per_dim_scale_cache_[H][per_dim_scale.get()] = cached_pds; + } + + return ElementMul(input, per_dim_scale_cache_[H][per_dim_scale.get()]); +} + +absl::StatusOr> XnnGraphBuilder::Rope( + std::shared_ptr input, std::shared_ptr segment_pos, + absl::SourceLocation loc) { + // TODO: seg_pos should not be weight. 
+ rope_weigths_.insert(segment_pos); + + const auto& input_dim = input->dims; + const auto& segment_pos_dim = segment_pos->dims; + // B T N H + RET_CHECK_EQ(input_dim.size(), 4) << "xnn requirement"; + // S H + RET_CHECK_EQ(segment_pos_dim.size(), 2) << "xnn requirement"; + + ASSIGN_OR_RETURN(auto output, IntermediateTensor(input_dim)); + + const auto input_seq_size = input_dim[1]; + RET_CHECK_LE(input_seq_size, segment_pos_dim[0]); + const auto head_dim_H = input_dim[3]; + RET_CHECK_EQ(head_dim_H, segment_pos_dim[1]); + + build_steps_.push_back( + {loc, + [this, input, output, segment_pos, + input_seq_size](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); + RET_CHECK_EQ( + xnn_status_success, + xnn_define_rope(subgraph, input_seq_size, input->tensor_id, + segment_pos->tensor_id, output->tensor_id, + /*flags=*/0)); + return absl::OkStatus(); + }}); + + return output; +} + +absl::StatusOr> XnnGraphBuilder::BatchMatMul( + std::shared_ptr input, std::shared_ptr weight, + FullConnParams params, absl::SourceLocation loc) { + const auto& lhs_dim = input->dims; + const auto& rhs_dim = weight->dims; + + // [B, N, T, H] . 
[B, N, S, H], N == 12, B == 1 + DCHECK_EQ(lhs_dim.size(), 4); + DCHECK_EQ(rhs_dim.size(), 4); + DCHECK_EQ(lhs_dim.back(), rhs_dim.back()); + DCHECK_EQ(lhs_dim.back(), rhs_dim.back()); + constexpr size_t num_slices = 12; + DCHECK_EQ(lhs_dim[1], num_slices); + DCHECK_EQ(rhs_dim[1], num_slices); + const size_t S = rhs_dim[2]; + const size_t T = lhs_dim[2]; + const size_t batch_size = lhs_dim[0] * lhs_dim[1]; + DCHECK_EQ(batch_size, rhs_dim[0] * rhs_dim[1]); + DCHECK_EQ(batch_size, 12); + + ASSIGN_OR_RETURN(auto output, IntermediateTensor({1, 12, T, S})); + + build_steps_.push_back( + {loc, [input, output, weight](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); + + RET_CHECK_EQ(xnn_status_success, + xnn_define_batch_matrix_multiply( + subgraph, input->tensor_id, weight->tensor_id, + output->tensor_id, /*flags=*/0)); + + return absl::OkStatus(); + }}); + + return output; +} + +absl::StatusOr> XnnGraphBuilder::Tanh( + std::shared_ptr input, absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims)); + + build_steps_.push_back( + {loc, [this, input, output](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); + + RET_CHECK_EQ(xnn_status_success, + xnn_define_tanh(subgraph, input->tensor_id, + output->tensor_id, /*flags=*/0)); + return absl::OkStatus(); + }}); + + return output; +} + +absl::StatusOr> XnnGraphBuilder::CapTanh( + std::shared_ptr input, float cap, absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto div, ElementDiv(input, cap)); + ASSIGN_OR_RETURN(auto tanh, Tanh(div)); + return ElementMul(tanh, cap); +} + +absl::StatusOr> XnnGraphBuilder::DotAttention( + std::shared_ptr query_proj, std::shared_ptr key_proj, + std::shared_ptr value_proj, std::shared_ptr atten_mask, + std::shared_ptr per_dim_scale, absl::SourceLocation loc) { + // BTNH + ASSIGN_OR_RETURN(auto query_after_scale, + 
PerDimScale(query_proj, per_dim_scale)); + + // Dot similarity + // BTNH -> BNTH + ASSIGN_OR_RETURN(auto query_permuted, + Permute(query_after_scale, {0, 2, 1, 3})); + // BSNH -> BNSH + ASSIGN_OR_RETURN(auto key_permuted, Permute(key_proj, {0, 2, 1, 3})); + // einsum(BNTH.BNSH -> BNTS) + ASSIGN_OR_RETURN(auto logits, BatchMatMul(query_permuted, key_permuted)); + + // Cap, mask + ASSIGN_OR_RETURN(auto cap_logits, CapTanh(logits, 50)); + ASSIGN_OR_RETURN(auto padded_logits, ElementAdd(atten_mask, cap_logits)); + ASSIGN_OR_RETURN(auto probs, Softmax(padded_logits)); + ASSIGN_OR_RETURN(auto value_permuted, Permute(value_proj, {0, 2, 3, 1})); + + // Outcome + // BNTS.BNHS -> BNTH + ASSIGN_OR_RETURN(auto outcome_before_permute, + BatchMatMul(probs, value_permuted)); + // [B, N, T, H] -> BTNH + return Permute(outcome_before_permute, {0, 2, 1, 3}); +} + +absl::StatusOr> XnnGraphBuilder::SelfAttentionProj( + std::shared_ptr input, std::shared_ptr weight, + absl::SourceLocation loc) { + const auto& input_dim = input->dims; + const auto& weight_dim = weight->dims; + size_t N = 0, H = 0; + RET_CHECK_EQ(input_dim.size(), 3) << "BTD"; + + std::optional reshaped_N = + weight->GetMetadata(kKeySelfAttentionReshapedWeight); + RET_CHECK(reshaped_N && *reshaped_N) + << "We rely on " << kKeySelfAttentionReshapedWeight << " to get N"; + RET_CHECK_EQ(weight_dim.size(), 2) << "NH,D"; + N = *reshaped_N; + H = weight_dim[0] / N; + + // out: B,T,NH + ASSIGN_OR_RETURN(auto proj, MatMul(input, weight)); + + // B,T,NH -> B,T,N,H + return Reshape(proj, {input_dim[0], input_dim[1], N, H}); +} + +absl::Status XnnGraph::CreateRuntime() { + RET_CHECK_EQ(runtime_.get(), nullptr); + xnn_runtime_t runtime_ptr = nullptr; + uint32_t flags = 0; + if (runtime_configs_->xnn_profile) { + flags |= XNN_FLAG_BASIC_PROFILING; + + if (!runtime_configs_->xnn_profile_csv.empty()) { + MP_RETURN_IF_ERROR(file::SetContents(runtime_configs_->xnn_profile_csv, + "node_id; time(us); op_name\n", + file::Defaults())); + } + 
} + pthreadpool_t threadpool = + pthreadpool_create(runtime_configs_->xnn_num_threads); + threadpool_ = XnnThreadpoolPtr{threadpool, pthreadpool_destroy}; + + RET_CHECK_EQ(xnn_status_success, + xnn_create_runtime_v2(owned_subgraph_.get(), threadpool, flags, + &runtime_ptr)); + RET_CHECK_NE(runtime_ptr, nullptr); + runtime_ = XnnRuntimePtr{runtime_ptr, xnn_delete_runtime}; + + return absl::OkStatus(); +} + +absl::Status XnnGraph::SetupRuntime() { + { + VLOG(3) << "input size " << input_tensors_.size(); + VLOG(3) << "output size " << output_tensors_.size(); + VLOG(3) << "rope size " << rope_weigths_.size(); + externals_.clear(); + // Init external + for (const auto& input : input_tensors_) { + VLOG(3) << "input id " << input->tensor_id; + externals_.push_back(xnn_external_value{input->tensor_id, input->Data()}); + } + for (const auto& output : output_tensors_) { + VLOG(3) << "output id " << output->tensor_id; + externals_.push_back( + xnn_external_value{output->tensor_id, output->Data()}); + } + for (const auto& t : rope_weigths_) { + VLOG(3) << "rope id " << t->tensor_id; + } + } + RET_CHECK_EQ( + xnn_status_success, + xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data())); + return absl::OkStatus(); +} + +absl::Status XnnGraph::Run() { + RET_CHECK(runtime_); + + RET_CHECK_EQ(xnn_status_success, xnn_invoke_runtime(runtime_.get())); + + if (runtime_configs_->xnn_profile) { + size_t required_size = 0; + + // xnn_get_runtime_profiling_info is called twice. The first time it sets + // required_size to the required size of the buffer to store the result and + // returns xnn_status_out_of_memory. The second time it writes the result to + // the buffer provided that the buffer is large enough and returns + // xnn_status_success. 
+ xnn_status status = xnn_get_runtime_profiling_info( + runtime_.get(), xnn_profile_info_operator_name, /*param_value_size*/ 0, + /*param_value*/ nullptr, &required_size); + std::vector operator_names; + if (status == xnn_status_out_of_memory) { + operator_names.resize(required_size); + status = xnn_get_runtime_profiling_info( + runtime_.get(), xnn_profile_info_operator_name, operator_names.size(), + operator_names.data(), &required_size); + } + RET_CHECK_EQ(status, xnn_status_success); + size_t num_operators; + status = xnn_get_runtime_profiling_info( + runtime_.get(), xnn_profile_info_num_operators, sizeof(num_operators), + &num_operators, &required_size); + RET_CHECK_EQ(status, xnn_status_success); + status = xnn_get_runtime_profiling_info( + runtime_.get(), xnn_profile_info_operator_timing, + /*param_value_size*/ 0, + /*param_value*/ nullptr, &required_size); + std::vector operator_timings; + if (status == xnn_status_out_of_memory) { + operator_timings.resize(required_size / sizeof(uint64_t)); + status = xnn_get_runtime_profiling_info( + runtime_.get(), xnn_profile_info_operator_timing, + operator_timings.size() * sizeof(uint64_t), operator_timings.data(), + &required_size); + } + RET_CHECK_EQ(status, xnn_status_success); + const char* operator_name = nullptr; + size_t name_len = 0; + std::stringstream ss; + for (size_t node_index = 0; node_index < num_operators; ++node_index) { + operator_name = &operator_names[name_len]; + name_len += strlen(operator_name) + 1; + VLOG(2) << "XnnGraphBuilder::Profile() node_index: " << node_index + << ", time: " << operator_timings[node_index] << " us, " + << operator_name << "\n"; + if (!runtime_configs_->xnn_profile_csv.empty()) { + // Use ';' instead of ',' because operator_name contains comma. 
+ ss << node_index << "; " << operator_timings[node_index] << "; " + << operator_name << "\n"; + } + } + if (!runtime_configs_->xnn_profile_csv.empty()) { + MP_RETURN_IF_ERROR(file::AppendStringToFile( + runtime_configs_->xnn_profile_csv, ss.str(), file::Defaults())); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr> XnnGraphBuilder::Clamp( + std::shared_ptr input, ClampParams params, + absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims)); + + build_steps_.push_back( + {loc, + [this, input, output, params](xnn_subgraph_t subgraph) -> absl::Status { + MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); + + RET_CHECK_EQ(xnn_status_success, + xnn_define_clamp(subgraph, params.out_min, params.out_max, + input->tensor_id, output->tensor_id, + /*flags=*/0)); + return absl::OkStatus(); + }}); + + return output; +} + +absl::StatusOr> XnnGraphBuilder::Gelu( + std::shared_ptr input, absl::SourceLocation loc) { + // x^2 + ASSIGN_OR_RETURN(auto sqr_out, Square(input)); + + // 0.044715 * x^2 + ASSIGN_OR_RETURN(auto sqr_4471, ElementMul(sqr_out, 0.044715)); + + // 1 + 0.044715 * x^2 + ASSIGN_OR_RETURN(auto sqr_4471_1, ElementAdd(sqr_4471, 1.0f)); + + // x + 0.044715 * x^3 + ASSIGN_OR_RETURN(auto x_cube_4471, ElementMul(sqr_4471_1, input)); + + constexpr float sqrt_2_over_pi = 0.7978845608; + ASSIGN_OR_RETURN(auto sqrt_2_over_pi_x_cube_4471, + ElementMul(x_cube_4471, sqrt_2_over_pi)); + + // tanh(x + 0.044715 * x^3) + ASSIGN_OR_RETURN(auto tanh_x_cube_4471, Tanh(sqrt_2_over_pi_x_cube_4471)); + + // 1 + tanh(x + 0.044715 * x^3) + ASSIGN_OR_RETURN(auto tanh_x_cube_4471_1, ElementAdd(tanh_x_cube_4471, 1.0f)); + + // 0.5 * (1 + [tanh(x + 0.044715 * x^3)]) + ASSIGN_OR_RETURN(auto cdf, ElementMul(tanh_x_cube_4471_1, 0.5)); + + return ElementMul(input, cdf); +} + +} // namespace xnn_utils +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h 
b/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h new file mode 100644 index 000000000..24b7520ba --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h @@ -0,0 +1,288 @@ +#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/source_location.h" +#include "file/base/helpers.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" +#include "third_party/XNNPACK/include/xnnpack.h" + +namespace mediapipe { +namespace xnn_utils { + +using XnnSubgraphPtr = + std::unique_ptr; +using XnnRuntimePtr = + std::unique_ptr; +using XnnThreadpoolPtr = + std::unique_ptr; + +struct ClampParams { + float out_min = -std::numeric_limits::infinity(); + float out_max = std::numeric_limits::infinity(); +}; + +struct FullConnParams : public ClampParams { + bool transpose = false; +}; + +struct RuntimeConfigs { + bool xnn_profile; + std::string xnn_profile_csv; + size_t xnn_num_threads; +}; + +class XnnGraph; + +// XnnGraphBuilder is used to construct XnnGraph (through Build()). Once a +// XnnGraph is constructed, it can run for multiple times. +class XnnGraphBuilder { + public: + static constexpr absl::string_view kKeySelfAttentionReshapedWeight{ + "self_attention_reshaped_weight_N"}; + + explicit XnnGraphBuilder(xnn_datatype data_type = xnn_datatype_fp32) + : data_type_(data_type) {} + virtual ~XnnGraphBuilder() = default; + + absl::StatusOr> Build( + std::unique_ptr runtime_configs = nullptr); + + // New input or output tensor. 
+ absl::StatusOr> NewInput( + Tensor::DimsType dims, + absl::SourceLocation loc = absl::SourceLocation::current()); + + // New static weight, populate value before Build() + absl::StatusOr> NewWeight( + Tensor::DimsType dims, + absl::SourceLocation loc = absl::SourceLocation::current()); + absl::StatusOr> NewWeight( + absl::string_view file_path, Tensor::DimsType dims, + absl::SourceLocation loc = absl::SourceLocation::current()); + void NewWeight(std::shared_ptr t, + absl::SourceLocation loc = absl::SourceLocation::current()); + + // Element wise square. + absl::StatusOr> Square( + std::shared_ptr input, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> SquareRoot( + std::shared_ptr input, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> Gelu( + std::shared_ptr input, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> Clamp( + std::shared_ptr input, ClampParams params, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> Tanh( + std::shared_ptr input, + absl::SourceLocation loc = absl::SourceLocation::current()); + + // logits = cap * jnp.tanh(logits / cap) + absl::StatusOr> CapTanh( + std::shared_ptr input, float cap, + absl::SourceLocation loc = absl::SourceLocation::current()); + + // Average over last dimension, keep num of dims same. 
+ absl::StatusOr> AvgLastDim( + std::shared_ptr input, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> Rms( + std::shared_ptr input, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> RmsNorm( + std::shared_ptr input, std::shared_ptr scale, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> Reshape( + std::shared_ptr input, Tensor::DimsType new_dims, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> Permute( + std::shared_ptr input, Tensor::DimsType permute, + absl::SourceLocation loc = absl::SourceLocation::current()); + + // input: [B * I] + // filter: [O * I], [I * O] if transpose + // return: [B * O] + absl::StatusOr> MatMul( + std::shared_ptr input, std::shared_ptr weight, + absl::SourceLocation loc = absl::SourceLocation::current()) { + return MatMul(input, weight, FullConnParams(), loc); + } + + absl::StatusOr> MatMul( + std::shared_ptr input, std::shared_ptr weight, + FullConnParams params, + absl::SourceLocation loc = absl::SourceLocation::current()) { + return FullConn(input, weight, nullptr, params, loc); + } + + absl::StatusOr> BatchMatMul( + std::shared_ptr input, std::shared_ptr weight, + FullConnParams params = FullConnParams(), + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> FullConn( + std::shared_ptr input, std::shared_ptr weight, + std::shared_ptr bias, + absl::SourceLocation loc = absl::SourceLocation::current()) { + return FullConn(input, weight, bias, FullConnParams(), loc); + } + + absl::StatusOr> FullConn( + std::shared_ptr input, std::shared_ptr weight, + std::shared_ptr bias, FullConnParams params, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> Softmax( + std::shared_ptr input, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> SelfAttentionProj( + std::shared_ptr input, std::shared_ptr weight, + 
absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> ElementAdd( + std::shared_ptr lhs, std::shared_ptr rhs, + ClampParams params = ClampParams(), + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> ElementAdd( + std::shared_ptr lhs, float rhs, + ClampParams params = ClampParams(), + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> ElementMul( + std::shared_ptr lhs, std::shared_ptr rhs, + ClampParams params = ClampParams(), + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> ElementMul( + std::shared_ptr lhs, float rhs, + ClampParams params = ClampParams(), + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> ElementDiv( + std::shared_ptr lhs, std::shared_ptr rhs, + ClampParams params = ClampParams(), + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> ElementDiv( + std::shared_ptr lhs, float rhs, + ClampParams params = ClampParams(), + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> Rope( + std::shared_ptr input, std::shared_ptr segment_pos, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> PerDimScale( + std::shared_ptr input, std::shared_ptr per_dim_scale, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> DotAttention( + std::shared_ptr query_proj, std::shared_ptr key_proj, + std::shared_ptr value_proj, std::shared_ptr atten_mask, + std::shared_ptr per_dim_scale, + absl::SourceLocation loc = absl::SourceLocation::current()); + + protected: + absl::StatusOr> IntermediateTensor( + Tensor::DimsType dims, + absl::SourceLocation loc = absl::SourceLocation::current()); + + const xnn_datatype data_type_; + + std::vector>> + build_steps_; + + absl::flat_hash_set> input_tensors_; + absl::flat_hash_set> interm_tensors_; + + // TODO: fix this. 
+ // This is sort of bug that the weights used for rope has to be defined with + // EXTERNAL flag, but with id out of the external range. + absl::flat_hash_set> rope_weigths_; + + // Caches + absl::flat_hash_map< + size_t /*dim*/, + absl::flat_hash_map>> + per_dim_scale_cache_; +}; + +class XnnGraph { + public: + XnnGraph(XnnSubgraphPtr subgraph, + std::unique_ptr runtime_configs) + : owned_subgraph_(std::move(subgraph)), + runtime_configs_(std::move(runtime_configs)) { + DCHECK(runtime_configs_); + } + XnnGraph(XnnGraph&& other) = default; + virtual ~XnnGraph() = default; + + // xnn_subgraph should be created with same size. + virtual absl::Status Run(); + + protected: + friend class XnnGraphBuilder; + + absl::Status CreateRuntime(); + absl::Status SetupRuntime(); + + XnnSubgraphPtr owned_subgraph_; + + absl::flat_hash_map avg_cache_; + absl::flat_hash_map cap_tanh_cache_; + + // Runtime + std::unique_ptr runtime_configs_; + XnnRuntimePtr runtime_{nullptr, xnn_delete_runtime}; + std::vector externals_; + + XnnThreadpoolPtr threadpool_{nullptr, pthreadpool_destroy}; + + absl::flat_hash_set> input_tensors_; + absl::flat_hash_set> output_tensors_; + // TODO: see above + absl::flat_hash_set> rope_weigths_; + + absl::flat_hash_set> interm_tensors_; +}; + +} // namespace xnn_utils +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_ diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.cc b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.cc new file mode 100644 index 000000000..f60e53394 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.cc @@ -0,0 +1,475 @@ +#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/port/ret_check.h" +#include 
"mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/text/text_generator/calculators/preprocessor_util.h" +#include "mediapipe/tasks/cc/text/text_generator/calculators/sampler_util.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" +#include "util/gtl/stl_logging.h" + +namespace mediapipe { +namespace xnn_utils { +namespace { + +absl::StatusOr> ApplyFinalProj( + std::shared_ptr inter_layer, const UlmWeights& weights, + XnnGraphBuilder& builder) { + return builder.FullConn(inter_layer, weights.softmax_linear, + weights.softmax_bias); +} + +} // namespace + +class OneTokenUlm : public Ulm { + public: + OneTokenUlm(std::unique_ptr full_ulm, XnnGraph&& other) + : Ulm(std::move(other)), full_ulm_(std::move(full_ulm)) {} + ~OneTokenUlm() override = default; + + absl::Status InitInputTokens(const std::vector& input_ids) override { + prev_ids_ = input_ids; + MP_RETURN_IF_ERROR(full_ulm_->InitInputTokens(input_ids)); + // prev_id.size - 1 is the output. 
+ return full_ulm_->Run(); + } + + absl::Status GetNextToken(std::vector* output_ids) override { + size_t decode_step = prev_ids_.size() - 1; + VLOG(2) << "Decode step " << decode_step; + + if (decode_step == ulm_params_.seq_size_T - 1) { + return absl::OutOfRangeError( + absl::StrCat("Hit max sequence length ", ulm_params_.seq_size_T)); + } + + transformer_input_->Borrow( + full_ulm_->transformer_input_->Slice(1, decode_step)); + atten_masks_->Borrow(full_ulm_->atten_masks_->Slice(0, decode_step)); + MP_RETURN_IF_ERROR(segment_pos_->LoadFromBuffer( + full_ulm_->segment_pos_->Slice(0, decode_step)->Data())); + for (auto& kv_cache : kv_cache_) { + DCHECK(kv_cache.k_slice); + DCHECK(kv_cache.v_slice); + kv_cache.k_slice->Borrow(kv_cache.k_cache->Slice(1, decode_step)); + kv_cache.v_slice->Borrow(kv_cache.v_cache->Slice(1, decode_step)); + } + + MP_RETURN_IF_ERROR(SetupRuntime()); + MP_RETURN_IF_ERROR(Run()); + + RET_CHECK(logits_output_); + DCHECK_EQ(logits_output_->num_elements, ulm_params_.voc_size_V); + + ASSIGN_OR_RETURN(*output_ids, + mediapipe::SampleNextToken( + logits_output_->DataAs(), + /*batch_size=*/1, + /*vocab_size=*/ulm_params_.voc_size_V, /*top_k=*/10, + /*top_p=*/1, /*temperature=*/-1)); + RET_CHECK_EQ(output_ids->size(), 1); + prev_ids_.push_back(output_ids->at(0)); + + return GetTokenEmbedding( + *output_ids, + pos_embedding_data_->Slice({decode_step + 1, 0})->DataAs(), + full_ulm_->transformer_input_->Slice({0, decode_step + 1, 0}) + ->DataAs()); + } + + private: + std::unique_ptr full_ulm_; +}; + +absl::StatusOr> UlmBuilder::SelfAttentionExcludeNorm( + std::shared_ptr input, SelfAttentionArgs args, + const SelfAttentionWeights& sa_weights, absl::SourceLocation loc) { + // [B, 1|T, N, H] + ASSIGN_OR_RETURN(auto k_proj, SelfAttentionProj(input, sa_weights.k_weight)); + ASSIGN_OR_RETURN(auto q_proj, SelfAttentionProj(input, sa_weights.q_weight)); + ASSIGN_OR_RETURN(auto v_proj, SelfAttentionProj(input, sa_weights.v_weight)); + + 
ASSIGN_OR_RETURN(auto query_proj_after_rope, Rope(q_proj, args.segment_pos)); + ASSIGN_OR_RETURN(auto key_proj_after_rope, Rope(k_proj, args.segment_pos)); + + if (args.cache) { + RET_CHECK(args.cache->k_cache); + RET_CHECK(args.cache->v_cache); + // When cache is provided, there are 2 cases: + if (*(input->dims.end() - 2) != 1) { + // Building a normal graph, which is used to initialize cache. + key_proj_after_rope->Borrow(args.cache->k_cache).MarkOutput(); + v_proj->Borrow(args.cache->v_cache).MarkOutput(); + } else { + // Building a one-token graph, which consumes initialized cache. + key_proj_after_rope->MarkOutput(); + args.cache->k_slice = key_proj_after_rope; + v_proj->MarkOutput(); + args.cache->v_slice = v_proj; + + ASSIGN_OR_RETURN(key_proj_after_rope, + NewInput(args.cache->k_cache->dims)); + key_proj_after_rope->Borrow(args.cache->k_cache); + ASSIGN_OR_RETURN(v_proj, NewInput(args.cache->v_cache->dims)); + v_proj->Borrow(args.cache->v_cache); + } + } + + // encoded, [B, 1|T, N, H] + ASSIGN_OR_RETURN( + auto kqv_merged, + DotAttention(query_proj_after_rope, key_proj_after_rope, v_proj, + args.atten_mask, sa_weights.per_dim_scale)); + + const size_t B = kqv_merged->dims[0]; + const size_t T_or_1 = kqv_merged->dims[1]; + const size_t NH = kqv_merged->num_elements / (B * T_or_1); + ASSIGN_OR_RETURN(auto outcome_reshaped, Reshape(kqv_merged, {B, T_or_1, NH})); + + return MatMul(outcome_reshaped, sa_weights.post_proj_weight, + {.transpose = false}); +} + +absl::StatusOr> +UlmBuilder::SelfAttentionIncludeResidual(std::shared_ptr input, + SelfAttentionArgs args, + const SelfAttentionWeights& params, + absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto pre_attention, RmsNorm(input, params.pre_norm)); + + ASSIGN_OR_RETURN( + auto post_attention, + SelfAttentionExcludeNorm(pre_attention, std::move(args), params)); + + ASSIGN_OR_RETURN(auto post_norm, RmsNorm(post_attention, params.post_norm)); + + return ElementAdd(input, post_norm); +} + +absl::StatusOr> 
UlmBuilder::FeedForwardExcludeResidual( + std::shared_ptr input, const FeedForwardWeights& params, + absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto first_rms_norm, RmsNorm(input, params.pre_norm)); + + ASSIGN_OR_RETURN(auto layer_1, FullConn(first_rms_norm, params.layer_1_weight, + params.layer_1_bias)); + + ASSIGN_OR_RETURN(auto layer_1_gate_before_gelu, + FullConn(first_rms_norm, params.layer_1_gate_weight, + params.layer_1_gate_bias)); + ASSIGN_OR_RETURN(auto layer_1_gate, Gelu(layer_1_gate_before_gelu)); + + ASSIGN_OR_RETURN(auto layer_1_and_gate, ElementMul(layer_1, layer_1_gate)); + if (params.opt_padding) { + // activations *= 1.0 - paddings + ASSIGN_OR_RETURN(auto tmp, ElementMul(params.opt_padding, -1.0f)); + ASSIGN_OR_RETURN(tmp, ElementMul(layer_1_and_gate, tmp)); + ASSIGN_OR_RETURN(layer_1_and_gate, ElementAdd(tmp, layer_1_and_gate)); + } + ASSIGN_OR_RETURN( + auto layer_2, + FullConn(layer_1_and_gate, params.layer_2_weight, params.layer_2_bias)); + if (params.opt_padding) { + // activations *= 1.0 - paddings + ASSIGN_OR_RETURN(auto tmp, ElementMul(params.opt_padding, -1.0f)); + ASSIGN_OR_RETURN(tmp, ElementMul(layer_2, tmp)); + ASSIGN_OR_RETURN(layer_2, ElementAdd(tmp, layer_2)); + } + + return RmsNorm(layer_2, params.post_norm); +} + +absl::StatusOr> UlmBuilder::FeedForwardIncludeResidual( + std::shared_ptr input, const FeedForwardWeights& params, + absl::SourceLocation loc) { + ASSIGN_OR_RETURN(auto before_residual, + FeedForwardExcludeResidual(input, params)); + return ElementAdd(before_residual, input); +} + +absl::StatusOr> Ulm::CreateUlm( + absl::string_view weights_folder, const UlmParams& ulm_params, + std::unique_ptr runtime_configs) { + auto weight_loader = + std::make_unique(weights_folder, ulm_params); + return CreateUlm(std::move(weight_loader), std::move(runtime_configs)); +} + +absl::StatusOr> Ulm::CreateOneTokenUlm( + std::unique_ptr weight_loader, + std::unique_ptr runtime_configs) { + UlmBuilder builder; + // TODO: might be 
memory waste here, benchmark. + weight_loader->SetBuilder(builder); + ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights()); + + UlmParams ulm_params = weight_loader->ulm_params(); + ulm_params.enable_kv_cache = true; + + weight_loader->ulm_params().enable_kv_cache = true; + weight_loader->ulm_params().final_norm = false; + weight_loader->ulm_params().final_project = false; + ASSIGN_OR_RETURN(auto full_ulm, CreateUlm(std::move(weight_loader))); + + ASSIGN_OR_RETURN(auto input, builder.NewInput({ulm_params.batch_size_B, 1, + ulm_params.model_dim_D})); + ASSIGN_OR_RETURN(auto atten_masks, + builder.NewInput({1, ulm_params.seq_size_T})); + ASSIGN_OR_RETURN(auto segment_pos, + builder.NewWeight({1, ulm_params.head_dim_H})); + // To allocate buffer before creating runtime. + MP_RETURN_IF_ERROR(segment_pos->LoadFromVec({}, /*exact_match=*/false)); + + std::vector& kv_cache = full_ulm->kv_cache_; + RET_CHECK_EQ(kv_cache.size(), ulm_params.num_transformer_M); + + auto inter_layer = input; + for (int i = 0; i < ulm_params.num_transformer_M; ++i) { + const auto& sa = weights.sas[i]; + ASSIGN_OR_RETURN(auto tmp, builder.SelfAttentionIncludeResidual( + inter_layer, + {.atten_mask = atten_masks, + .segment_pos = segment_pos, + .cache = &kv_cache[i]}, + sa)); + + auto& ff = weights.ffs[i]; + // ff.opt_padding = paddings; + ASSIGN_OR_RETURN(inter_layer, builder.FeedForwardIncludeResidual(tmp, ff)); + } + + std::shared_ptr logits_output, transformer_output, normed_output; + + if (ulm_params.final_norm) { + ASSIGN_OR_RETURN(inter_layer, + builder.RmsNorm(inter_layer, weights.final_ln_scale)); + normed_output = inter_layer; + normed_output->MarkOutput(); + } + if (ulm_params.final_project) { + RET_CHECK(weights.softmax_linear); + ASSIGN_OR_RETURN(logits_output, + ApplyFinalProj(inter_layer, weights, builder)); + logits_output->MarkOutput(); + } + + ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs))); + Ulm* full_ulm_p = full_ulm.get(); + auto result = + 
std::make_unique(std::move(full_ulm), std::move(*graph)); + { + Tensor::DimsType dims{ulm_params.seq_size_T, ulm_params.model_dim_D}; + result->pos_embedding_data_ = + std::make_shared(std::move(dims), xnn_datatype_fp32); + result->pos_embedding_data_->Borrow(full_ulm_p->pos_embedding_data_); + } + result->transformer_input_ = input; + result->transformer_output_ = transformer_output; + result->normed_output_ = normed_output; + result->logits_output_ = logits_output; + result->segment_pos_ = segment_pos; + result->atten_masks_ = atten_masks; + if (ulm_params.use_padding) { + // result->paddings_ = paddings; + } + result->kv_cache_ = std::move(kv_cache); + + result->weights_ = std::move(weights); + result->ulm_params_ = ulm_params; + + return result; +} + +absl::StatusOr> Ulm::CreateUlm( + std::unique_ptr weight_loader, + std::unique_ptr runtime_configs) { + UlmBuilder builder; + weight_loader->SetBuilder(builder); + const auto& ulm_params = weight_loader->ulm_params(); + RET_CHECK_NE(ulm_params.batch_size_B, 0); + + ASSIGN_OR_RETURN(auto input, builder.NewInput({ulm_params.batch_size_B, + ulm_params.seq_size_T, + ulm_params.model_dim_D})); + ASSIGN_OR_RETURN(auto atten_masks, builder.NewInput({ulm_params.seq_size_T, + ulm_params.seq_size_T})); + VLOG(1) << "atten mask id " << atten_masks->tensor_id; + ASSIGN_OR_RETURN( + auto segment_pos, + builder.NewWeight({ulm_params.seq_size_T, ulm_params.head_dim_H})); + MP_RETURN_IF_ERROR(FillXnnRoPEWeights(*segment_pos)); + VLOG(1) << "segment pos id " << segment_pos->tensor_id; + std::shared_ptr paddings; + if (ulm_params.use_padding) { + ASSIGN_OR_RETURN(paddings, builder.NewInput({ulm_params.batch_size_B, + ulm_params.seq_size_T, 1})); + VLOG(1) << "paddings id " << paddings->tensor_id; + } + + ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights()); + std::vector kv_cache; + + auto inter_layer = input; + for (int i = 0; i < ulm_params.num_transformer_M; ++i) { + const auto& sa = weights.sas[i]; + KVCache* cache = 
nullptr; + if (ulm_params.enable_kv_cache) { + auto k_cache = std::make_shared( + Tensor::DimsType{ulm_params.batch_size_B, ulm_params.seq_size_T, + ulm_params.n_heads_N, ulm_params.head_dim_H}); + MP_RETURN_IF_ERROR(k_cache->LoadFromVec({}, /*exact_match=*/false)); + auto v_cache = std::make_shared( + Tensor::DimsType{ulm_params.batch_size_B, ulm_params.seq_size_T, + ulm_params.n_heads_N, ulm_params.head_dim_H}); + MP_RETURN_IF_ERROR(v_cache->LoadFromVec({}, /*exact_match=*/false)); + kv_cache.push_back(KVCache{.k_cache = k_cache, .v_cache = v_cache}); + cache = &kv_cache.back(); + } + ASSIGN_OR_RETURN(auto tmp, builder.SelfAttentionIncludeResidual( + inter_layer, + {.atten_mask = atten_masks, + .segment_pos = segment_pos, + .cache = cache}, + sa)); + + auto& ff = weights.ffs[i]; + ff.opt_padding = paddings; + ASSIGN_OR_RETURN(inter_layer, builder.FeedForwardIncludeResidual(tmp, ff)); + } + + std::shared_ptr logits_output, transformer_output, normed_output; + + if (!ulm_params.final_norm && !ulm_params.final_project) { + transformer_output = inter_layer; + transformer_output->MarkOutput(); + } + + if (ulm_params.final_norm) { + ASSIGN_OR_RETURN(inter_layer, + builder.RmsNorm(inter_layer, weights.final_ln_scale)); + normed_output = inter_layer; + normed_output->MarkOutput(); + } + + if (ulm_params.final_project) { + RET_CHECK(weights.softmax_linear); + ASSIGN_OR_RETURN(logits_output, + ApplyFinalProj(inter_layer, weights, builder)); + logits_output->MarkOutput(); + } + + ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs))); + auto ulm = std::make_unique(std::move(*graph)); + { + ASSIGN_OR_RETURN(auto pos_embedding_data, + mediapipe::PositionEmbedding(ulm_params.seq_size_T, + ulm_params.model_dim_D)); + Tensor::DimsType dims{ulm_params.seq_size_T, ulm_params.model_dim_D}; + ulm->pos_embedding_data_ = + std::make_shared(std::move(dims), xnn_datatype_fp32); + MP_RETURN_IF_ERROR( + ulm->pos_embedding_data_->LoadFromVec(pos_embedding_data)); + } + 
ulm->transformer_input_ = input; + ulm->transformer_output_ = transformer_output; + ulm->normed_output_ = normed_output; + ulm->logits_output_ = logits_output; + ulm->segment_pos_ = segment_pos; + ulm->atten_masks_ = atten_masks; + if (ulm_params.use_padding) { + ulm->paddings_ = paddings; + } + ulm->kv_cache_ = std::move(kv_cache); + + ulm->weights_ = std::move(weights); + ulm->ulm_params_ = ulm_params; + + return ulm; +} + +absl::Status Ulm::InitInputTokens(const std::vector& input_ids) { + prev_ids_ = input_ids; + + constexpr float neg_value = 0.7 * std::numeric_limits::lowest(); + const auto& seq_size = ulm_params_.seq_size_T; + std::vector attention_array(seq_size * seq_size, neg_value); + for (int i = 0; i < seq_size; ++i) { + for (int j = 0; j < seq_size; ++j) { + if (i < input_ids.size() && j < input_ids.size()) { + attention_array[seq_size * i + j] = 0; + } else if (i >= seq_size && j <= i) { + attention_array[seq_size * i + j] = 0; + } else { + break; + } + } + } + + MP_RETURN_IF_ERROR(atten_masks_->LoadFromVec(attention_array)); + + MP_RETURN_IF_ERROR(GetTokenEmbedding(input_ids, + pos_embedding_data_->DataAs(), + transformer_input_->DataAs())); + return SetupRuntime(); +} + +absl::Status Ulm::GetNextToken(std::vector* output_ids) { + VLOG(2) << "Decode step " << prev_ids_.size() - 1; + + MP_RETURN_IF_ERROR(Run()); + + RET_CHECK(logits_output_); + std::shared_ptr logits = + logits_output_->Slice({0, prev_ids_.size() - 1, 0}); + DCHECK_EQ(logits->num_elements, ulm_params_.voc_size_V); + + ASSIGN_OR_RETURN(*output_ids, + mediapipe::SampleNextToken( + logits->DataAs(), + /*batch_size=*/1, + /*vocab_size=*/ulm_params_.voc_size_V, /*top_k=*/10, + /*top_p=*/1, /*temperature=*/-1)); + RET_CHECK_EQ(output_ids->size(), 1); + prev_ids_.push_back(output_ids->at(0)); + + return GetTokenEmbedding( + *output_ids, + pos_embedding_data_->Slice({prev_ids_.size() - 1, 0})->DataAs(), + transformer_input_->Slice({0, prev_ids_.size() - 1, 0})->DataAs()); +} + +absl::Status 
Ulm::GetTokenEmbedding(const std::vector& ids, + const float* pos_embedding_data, + float* embedding) { + auto token_embedding = weights_.token_embedding ? weights_.token_embedding + : weights_.softmax_linear; + RET_CHECK(token_embedding->dims[0] == ulm_params_.voc_size_V) + << "shape must be [vocab_size, _], such that following Slice() makes " + "sense."; + for (size_t id : ids) { + memcpy(embedding, token_embedding->Slice(0, id)->Data(), + ulm_params_.model_dim_D * sizeof(float)); + for (size_t i = 0; i < ulm_params_.model_dim_D; ++i) { + embedding[i] += pos_embedding_data[i]; + } + pos_embedding_data += ulm_params_.model_dim_D; + embedding += ulm_params_.model_dim_D; + } + return absl::OkStatus(); +} + +} // namespace xnn_utils +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h new file mode 100644 index 000000000..7bf7de5a9 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h @@ -0,0 +1,127 @@ +#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" + +namespace mediapipe { +namespace xnn_utils { + +class Ulm : public XnnGraph { + public: + using UlmParams = UlmParams; + + explicit Ulm(XnnGraph&& other) : XnnGraph(std::move(other)) {} + ~Ulm() override = default; + + // Creating ULM graph with default params. The default param corresponds to + // ULM1B 256k model. 
+ static absl::StatusOr> CreateUlm( + absl::string_view weights_folder, + const UlmParams& ulm_params = + UlmParams{ + .num_transformer_M = 18, + .batch_size_B = 1, + .seq_size_T = 16, + .model_dim_D = 1536, + .hidden_dim_HD = 8 * 1536, + .head_dim_H = 128, + .n_heads_N = 12, + .voc_size_V = 256128, + }, + std::unique_ptr runtime_configs = nullptr); + static absl::StatusOr> CreateUlm( + std::unique_ptr weight_loader, + std::unique_ptr runtime_configs = nullptr); + // Build the graph for one-token inference. + static absl::StatusOr> CreateOneTokenUlm( + std::unique_ptr weight_loader, + std::unique_ptr runtime_configs = nullptr); + + // (Re)Initialize with input token ids. This will reset the cache, mask etc. + virtual absl::Status InitInputTokens(const std::vector& input_ids); + + // Get the next token id. + virtual absl::Status GetNextToken(std::vector* output_ids); + + protected: + friend class OneTokenUlm; + friend class UlmTest; + friend class UlmBuilder; + + // Enable if enable_kv_cache + struct KVCache { + std::shared_ptr k_cache; + std::shared_ptr v_cache; + std::shared_ptr k_slice; + std::shared_ptr v_slice; + }; + + absl::Status GetTokenEmbedding(const std::vector& ids, + const float* pos_embedding_data, + float* embedding); + + UlmWeights weights_; + UlmParams ulm_params_; + + std::shared_ptr pos_embedding_data_; + std::shared_ptr atten_masks_; + std::shared_ptr segment_pos_; + std::shared_ptr paddings_; + + std::shared_ptr transformer_input_; + std::shared_ptr transformer_output_; + std::shared_ptr normed_output_; + std::shared_ptr logits_output_; + + // Previous ids, including prompt. + std::vector prev_ids_; + // If enable_kv_cache, expect a mask of [0, ... 0, 1, 0, 0...], size 1 x T. 
+ std::shared_ptr decode_step_mask_; + // [1, 1, ..., 1, 0, 0...], applied on cache + std::shared_ptr decode_step_mask_for_cache_; + std::vector kv_cache_; +}; + +class UlmBuilder : public XnnGraphBuilder { + public: + struct SelfAttentionArgs { + std::shared_ptr atten_mask; + std::shared_ptr segment_pos; + + Ulm::KVCache* cache = nullptr; + }; + + absl::StatusOr> SelfAttentionExcludeNorm( + std::shared_ptr input, SelfAttentionArgs args, + const SelfAttentionWeights& sa_weights, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> SelfAttentionIncludeResidual( + std::shared_ptr input, SelfAttentionArgs args, + const SelfAttentionWeights& params, + absl::SourceLocation loc = absl::SourceLocation::current()); + + absl::StatusOr> FeedForwardExcludeResidual( + std::shared_ptr input, const FeedForwardWeights& params, + absl::SourceLocation loc = absl::SourceLocation::current()); + absl::StatusOr> FeedForwardIncludeResidual( + std::shared_ptr input, const FeedForwardWeights& params, + absl::SourceLocation loc = absl::SourceLocation::current()); +}; + +} // namespace xnn_utils +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_ diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.cc b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.cc new file mode 100644 index 000000000..a33589a60 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.cc @@ -0,0 +1,366 @@ +#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "file/base/filesystem.h" +#include "file/base/options.h" +#include "file/base/path.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h" +#include 
"mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" +#include "third_party/XNNPACK/include/xnnpack.h" + +namespace mediapipe { +namespace xnn_utils { + +namespace { + +absl::StatusOr> LoadFromAbsPathPrefixHelper( + XnnGraphBuilder& builder, absl::string_view prefix, + const Tensor::DimsType& dims, size_t dim_scale_if_any) { + RET_CHECK(!prefix.empty() && prefix.back() != '.'); + std::vector filenames; + auto s = file::Match(absl::StrCat(prefix, "*"), &filenames, file::Defaults()); + if (!s.ok()) { + LOG(WARNING) << s; + return nullptr; + } else if (filenames.empty()) { + return nullptr; + } + + if (filenames.size() == 1) { + RET_CHECK_EQ(filenames[0], prefix); + return builder.NewWeight(filenames[0], dims); + } + + bool is_quantized_tensor = false; + for (const auto& filename : filenames) { + if (absl::StrContains(filename, kQuantizedScaleSuffix)) { + is_quantized_tensor = true; + continue; + } + } + + RET_CHECK(is_quantized_tensor) + << "At least one of {" << filenames << "} must be quantize scale file."; + + std::shared_ptr result; + result = std::make_shared(dims, dim_scale_if_any); + + MP_RETURN_IF_ERROR(result->LoadFromFile(prefix)); + builder.NewWeight(result); + + return result; +} + +absl::Status TransposeSelfAttentionWeight( + const UlmWeightsLoader& loader, std::shared_ptr& original_weight, + absl::string_view cache_file_prefix) { + const auto& ulm_param = loader.ulm_params(); + RET_CHECK(original_weight); + + std::optional from_cache = + original_weight->GetMetadata(UlmWeights::kKeyLoadedFromCache); + if (from_cache && *from_cache) { + return absl::OkStatus(); + } + + if (auto s = original_weight->DumpToFile(cache_file_prefix); !s.ok()) { + LOG(WARNING) << s; + } else { + MP_RETURN_IF_ERROR(original_weight->LoadFromFile(cache_file_prefix)); + } + loader.builder().NewWeight(original_weight); + original_weight->SetMetadata(XnnGraphBuilder::kKeySelfAttentionReshapedWeight, + ulm_param.n_heads_N); + return absl::OkStatus(); +} + +} // namespace + 
+absl::Status PrepareTokenEmbeddingDecorator::Decorate( + const UlmWeightsLoader& loader, UlmWeights& weight) { + if (weight.token_embedding) { + return absl::OkStatus(); + } + + const auto& ulm_params = loader.ulm_params(); + absl::string_view cache_path = loader.ulm_params().weight_cache_path; + std::string token_embedding_cache_path = + cache_path.empty() ? "" : file::JoinPath(cache_path, "token_embedding.w"); + // 1. try cache + if (!token_embedding_cache_path.empty()) { + auto token_embedding = + Tensor::FromFile(token_embedding_cache_path, + {ulm_params.voc_size_V, ulm_params.model_dim_D}); + if (token_embedding.ok()) { + weight.token_embedding = *token_embedding; + return absl::OkStatus(); + } + } + + // 2. fill embedding from softmax_linear + auto& softmax_linear = *weight.softmax_linear; + RET_CHECK(softmax_linear.dims[0] == ulm_params.voc_size_V) << softmax_linear; + if (softmax_linear.datatype == xnn_datatype_fp32) { + weight.token_embedding = softmax_linear.View(); + } else if (softmax_linear.datatype == xnn_datatype_qcint8) { + ASSIGN_OR_RETURN(weight.token_embedding, softmax_linear.ConvertToF32()); + } + + float* embedding_data = weight.token_embedding->DataAs(); + for (size_t i = 0; i < softmax_linear.num_elements; ++i) { + embedding_data[i] *= std::sqrt(loader.ulm_params().model_dim_D); + } + + // 3. 
save cache + if (!token_embedding_cache_path.empty()) { + MP_RETURN_IF_ERROR( + weight.token_embedding->DumpToFile(token_embedding_cache_path)); + return weight.token_embedding->LoadFromFile(token_embedding_cache_path); + } + + return absl::OkStatus(); +} + +absl::Status TransposeSelfAttentionWeightDecorator::Decorate( + const UlmWeightsLoader& loader, UlmWeights& weight) { + absl::string_view cache_path = loader.ulm_params().weight_cache_path; + if (cache_path.empty()) { + return absl::OkStatus(); + } + + for (size_t i = 0; i < weight.sas.size(); ++i) { + auto& sa = weight.sas[i]; + auto prefix = absl::StrCat(UlmWeightsLoader::kTransformerWeightPrefix, i, + ".self_attention."); + MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight( + loader, sa.k_weight, + file::JoinPath(cache_path, absl::StrCat(prefix, "k.w")))); + MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight( + loader, sa.q_weight, + file::JoinPath(cache_path, absl::StrCat(prefix, "q.w")))); + MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight( + loader, sa.v_weight, + file::JoinPath(cache_path, absl::StrCat(prefix, "v.w")))); + } + + return absl::OkStatus(); +} + +absl::StatusOr> UlmWeightsLoader::LoadFromAbsPathPrefix( + absl::string_view prefix, const Tensor::DimsType& dims, + size_t dim_scale_if_any) const { + return LoadFromAbsPathPrefixHelper(*builder_, prefix, dims, dim_scale_if_any); +} + +absl::StatusOr> +UlmWeightsLoader::TryCacheThenLoadSelfAttention( + absl::string_view filename_prefix) const { + ASSIGN_OR_RETURN( + auto r, + TryCacheThenLoadWeightTranspose( + filename_prefix, + {params_.model_dim_D, params_.n_heads_N * params_.head_dim_H}, 1)); + r->SetMetadata(XnnGraphBuilder::kKeySelfAttentionReshapedWeight, + params_.n_heads_N); + return r; +} + +absl::StatusOr> +UlmWeightsLoader::TryCacheThenLoadFeedForward( + absl::string_view filename_prefix, + std::optional dims) const { + if (!dims) { + dims = {params_.model_dim_D, params_.hidden_dim_HD}; + } + return 
TryCacheThenLoadWeightTranspose(filename_prefix, *dims, 1); +} + +absl::StatusOr> +UlmWeightsLoader::TryCacheThenLoadWeightTranspose( + absl::string_view filename_prefix, Tensor::DimsType original_dims, + size_t original_dim_cale) const { + if (!params_.weight_cache_path.empty()) { + auto cache_full_prefix = + file::JoinPath(params_.weight_cache_path, filename_prefix); + Tensor::DimsType cache_dim{original_dims.rbegin(), original_dims.rend()}; + ASSIGN_OR_RETURN(auto r, LoadFromAbsPathPrefix( + cache_full_prefix, std::move(cache_dim), + /*dim_scale_if_any=*/1 - original_dim_cale)); + if (r) { + r->SetMetadata(UlmWeights::kKeyLoadedFromCache, 1); + return r; + } + } + + ASSIGN_OR_RETURN(auto r, LoadFromAbsPathPrefix( + file::JoinPath(weight_path_, filename_prefix), + std::move(original_dims), + /*dim_scale_if_any=*/original_dim_cale)); + RET_CHECK(r) << file::JoinPath(weight_path_, filename_prefix); + r = r->Transpose(); + builder_->NewWeight(r); + return r; +} + +absl::StatusOr UlmWeightsLoader::LoadFeedForward( + int layer_id) { + absl::string_view weights_folder = weight_path_; + const auto& params = params_; + auto ff_file_prefix = + absl::StrCat(kTransformerWeightPrefix, layer_id, ".ff_layer."); + auto ff_prefix = file::JoinPath(weights_folder, ff_file_prefix); + FeedForwardWeights feed_forward; + + ASSIGN_OR_RETURN( + feed_forward.pre_norm, + LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "pre_layer_norm.scale"), + {params.model_dim_D})); + ASSIGN_OR_RETURN( + feed_forward.post_norm, + LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "post_layer_norm.scale"), + {params.model_dim_D})); + ASSIGN_OR_RETURN( + feed_forward.layer_1_bias, + LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer1.bias.b"), + {params.hidden_dim_HD})); + ASSIGN_OR_RETURN(feed_forward.layer_1_weight, + TryCacheThenLoadFeedForward( + absl::StrCat(ff_file_prefix, "ffn_layer1.linear.w"))); + ASSIGN_OR_RETURN( + feed_forward.layer_1_gate_bias, + LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, 
"ffn_layer1_gate.bias.b"), + {params.hidden_dim_HD})); + ASSIGN_OR_RETURN(feed_forward.layer_1_gate_weight, + TryCacheThenLoadFeedForward(absl::StrCat( + ff_file_prefix, "ffn_layer1_gate.linear.w"))); + ASSIGN_OR_RETURN( + feed_forward.layer_2_bias, + LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer2.bias.b"), + {params.model_dim_D}, /*dim_scale_if_any=*/0)); + ASSIGN_OR_RETURN( + feed_forward.layer_2_weight, + TryCacheThenLoadFeedForward( + absl::StrCat(ff_file_prefix, "ffn_layer2.linear.w"), + Tensor::DimsType{params.hidden_dim_HD, params.model_dim_D})); + + return feed_forward; +} + +absl::StatusOr UlmWeightsLoader::LoadSelfAttention( + int layer_id) { + absl::string_view weights_folder = weight_path_; + const auto& params = params_; + SelfAttentionWeights self_attention; + + auto sa_file_prefix = absl::StrCat(kTransformerWeightPrefix, layer_id); + auto sa_prefix = file::JoinPath(weights_folder, sa_file_prefix); + ASSIGN_OR_RETURN( + self_attention.pre_norm, + LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, ".pre_layer_norm.scale"), + {params.model_dim_D})); + ASSIGN_OR_RETURN( + self_attention.post_norm, + LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, ".post_layer_norm.scale"), + {params.model_dim_D})); + + absl::StrAppend(&sa_file_prefix, ".self_attention."); + + ASSIGN_OR_RETURN( + self_attention.k_weight, + TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "k.w"))); + ASSIGN_OR_RETURN( + self_attention.q_weight, + TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "q.w"))); + ASSIGN_OR_RETURN( + self_attention.v_weight, + TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "v.w"))); + + sa_prefix = file::JoinPath(weights_folder, sa_file_prefix); + ASSIGN_OR_RETURN(self_attention.per_dim_scale, + LoadFromAbsPathPrefix( + absl::StrCat(sa_prefix, "per_dim_scale.per_dim_scale"), + {params.head_dim_H})); + ASSIGN_OR_RETURN(self_attention.post_proj_weight, + LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, "post.w"), + 
{params.model_dim_D, + params.n_heads_N * params.head_dim_H}, + /*dim_scale_if_any=*/0)); + + return self_attention; +} + +absl::StatusOr UlmWeightsLoader::LoadWeights() { + absl::string_view weights_folder = weight_path_; + const auto& params = params_; + UlmWeights result; + + for (int layer_id = 0; layer_id < params.num_transformer_M; ++layer_id) { + ASSIGN_OR_RETURN(auto ff, LoadFeedForward(layer_id)); + result.ffs.push_back(std::move(ff)); + ASSIGN_OR_RETURN(auto sa, LoadSelfAttention(layer_id)); + result.sas.push_back(std::move(sa)); + } + if (params.final_norm) { + ASSIGN_OR_RETURN(result.final_ln_scale, + LoadFromAbsPathPrefix( + file::JoinPath(weights_folder, kFinalScaleFilename), + {params.model_dim_D})); + } + ASSIGN_OR_RETURN(result.softmax_bias, + LoadFromAbsPathPrefix( + file::JoinPath(weights_folder, kLogitsFfnBiasFilename), + {params.voc_size_V})); + ASSIGN_OR_RETURN(result.softmax_linear, + TryCacheThenLoadWeightTranspose( + kLogitsFfnWeightFilename, + {params.model_dim_D, params.voc_size_V}, 1)); + + return result; +} + +BenchmarkUlmWeightsLoader::BenchmarkUlmWeightsLoader(const UlmParams& params, + xnn_datatype data_type) + : DefaultUlmWeightsLoader("", params), data_type_(data_type) { + params_.weight_cache_path.clear(); +} + +absl::StatusOr> +BenchmarkUlmWeightsLoader::TryCacheThenLoadWeightTranspose( + absl::string_view filename_prefix, Tensor::DimsType original_dims, + size_t original_dim_cale) const { + auto result = std::make_shared( + Tensor::DimsType{original_dims.rbegin(), original_dims.rend()}, + 1 - original_dim_cale); + auto real_data = std::make_shared(result->num_elements, 0xA5); + result->flat_data = std::shared_ptr(real_data, real_data->data()); + auto real_scale = std::make_shared>( + original_dims[original_dim_cale], 1.0f); + result->scale_data = std::shared_ptr(real_scale, real_scale->data()); + builder_->NewWeight(result); + return result; +} + +absl::StatusOr> +BenchmarkUlmWeightsLoader::LoadFromAbsPathPrefix( + 
absl::string_view prefix, const Tensor::DimsType& dims, + size_t dim_scale_if_any) const { + // If loader calls this function directly, it's always non-quantized weights. + auto result = std::make_shared(dims); + MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false)); + builder_->NewWeight(result); + return result; +} + +} // namespace xnn_utils +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h new file mode 100644 index 000000000..f10d8706a --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h @@ -0,0 +1,192 @@ +#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" +#include "third_party/XNNPACK/include/xnnpack.h" + +namespace mediapipe { +namespace xnn_utils { + +struct UlmParams { + size_t num_transformer_M = 18; + size_t batch_size_B = 1; + size_t seq_size_T = 16; + size_t model_dim_D = 1536; + size_t hidden_dim_HD = 8 * 1536; + size_t head_dim_H = 128; + size_t n_heads_N = 12; + size_t voc_size_V = 32000; + + bool use_padding = true; + bool final_norm = true; + bool final_project = true; + + bool enable_kv_cache = false; + // Path to store reshaped weights as cache. Set empty to disable caching. 
+ std::string weight_cache_path; +}; + +struct SelfAttentionWeights { + std::shared_ptr pre_norm; + + std::shared_ptr k_weight; + std::shared_ptr q_weight; + std::shared_ptr v_weight; + std::shared_ptr per_dim_scale; + std::shared_ptr post_proj_weight; + + std::shared_ptr post_norm; +}; + +struct FeedForwardWeights { + std::shared_ptr pre_norm; + std::shared_ptr layer_1_weight; + std::shared_ptr layer_1_bias; + std::shared_ptr layer_1_gate_weight; + std::shared_ptr layer_1_gate_bias; + std::shared_ptr layer_2_weight; + std::shared_ptr layer_2_bias; + std::shared_ptr post_norm; + + std::shared_ptr opt_padding; +}; + +struct UlmWeights { + std::vector ffs; + std::vector sas; + std::shared_ptr final_ln_scale; + std::shared_ptr softmax_linear; + std::shared_ptr softmax_bias; + + // Optional. Usually softmax_linear can be used as embedding, but sometimes we + // need to scale/transpose it. + std::shared_ptr token_embedding; + + static constexpr absl::string_view kKeyLoadedFromCache{"loaded_from_cache"}; +}; + +class UlmWeightsLoader { + public: + constexpr static absl::string_view kTransformerWeightPrefix{ + "params.lm.transformer.x_layers_"}; + constexpr static absl::string_view kFinalScaleFilename{ + "params.lm.final_ln.scale"}; + constexpr static absl::string_view kLogitsFfnBiasFilename{ + "params.lm.softmax.logits_ffn.bias.b"}; + constexpr static absl::string_view kLogitsFfnWeightFilename{ + "params.lm.softmax.logits_ffn.linear.w"}; + + UlmWeightsLoader(absl::string_view weight_path, const UlmParams& params) + : weight_path_(weight_path), params_(params) {} + virtual ~UlmWeightsLoader() = default; + + void SetBuilder(XnnGraphBuilder& builder) { builder_ = &builder; } + + virtual absl::StatusOr LoadWeights(); + + virtual absl::StatusOr LoadSelfAttention(int layer_id); + virtual absl::StatusOr LoadFeedForward(int layer_id); + + UlmParams& ulm_params() { return params_; } + const UlmParams& ulm_params() const { return params_; } + XnnGraphBuilder& builder() const { 
return *builder_; } + + protected: + // Find the files that matches prefix, then read from file. + virtual absl::StatusOr> LoadFromAbsPathPrefix( + absl::string_view prefix, const Tensor::DimsType& dims, + size_t dim_scale_if_any) const; + absl::StatusOr> LoadFromAbsPathPrefix( + absl::string_view prefix, const Tensor::DimsType& dims) const { + return LoadFromAbsPathPrefix(prefix, dims, 0); + } + + absl::StatusOr> TryCacheThenLoadSelfAttention( + absl::string_view filename_prefix) const; + absl::StatusOr> TryCacheThenLoadFeedForward( + absl::string_view filename_prefix, + std::optional dims = std::nullopt) const; + virtual absl::StatusOr> + TryCacheThenLoadWeightTranspose(absl::string_view filename_prefix, + Tensor::DimsType original_dims, + size_t original_dim_cale) const; + + std::string weight_path_; + UlmParams params_; + XnnGraphBuilder* builder_ = nullptr; +}; + +// Try: 1. load token embedding from cache; 2. fill token embedding by transpose +// softmax linear then scale; 3. dump token embedding to cache. +struct PrepareTokenEmbeddingDecorator { + static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&); +}; +struct TransposeSoftmaxWeightDecorator { + static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&); +}; +struct TransposeSelfAttentionWeightDecorator { + // If KQV weight are reshaped, ignore. + // If KQV weight are not properly shaped, load from cache if any, or build. + // If KQV weight are missing, try loading from cache path, or fail if missing. + static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&); +}; + +// Apply some decoration (in order) to the weights loaded by base class. 
// Builds the row-major [max_seq_len, num_channels] table of rotary position
// embedding weights.
//
// For position `p` and channel `c` in [0, num_channels / 2):
//   angle                      = p * 1e-4 ^ (2c / num_channels)
//   out[p][c]                  = cos(angle)
//   out[p][c + num_channels/2] = sin(angle)
// When `num_channels` is odd the last channel of every row stays 0.
std::vector<float> FillXnnRoPEWeights(size_t max_seq_len, size_t num_channels) {
  const size_t half_channels = num_channels / 2;

  // The timescale depends only on the channel, so compute it once per channel.
  std::vector<double> timescales(half_channels);
  for (size_t channel = 0; channel < half_channels; ++channel) {
    timescales[channel] = std::pow(1e-4, 2.0 * channel / num_channels);
  }

  std::vector<float> weights(max_seq_len * num_channels);
  for (size_t position = 0; position < max_seq_len; ++position) {
    float* const row = weights.data() + position * num_channels;
    for (size_t channel = 0; channel < half_channels; ++channel) {
      const double angle = position * timescales[channel];
      row[channel] = static_cast<float>(std::cos(angle));
      row[channel + half_channels] = static_cast<float>(std::sin(angle));
    }
  }
  return weights;
}
+template +static absl::StatusOr> LoadBufferFromFile( + absl::string_view file_path, bool use_mmap = true, + size_t expect_size_bytes = 0) { + if (use_mmap) { + int fd = open(file_path.data(), O_RDONLY); + RET_CHECK_GE(fd, 0) << "open " << file_path << " failed"; + auto cleanup = absl::MakeCleanup([fd] { close(fd); }); + + const size_t size = lseek(fd, 0, SEEK_END); + if (expect_size_bytes) { + RET_CHECK_EQ(expect_size_bytes, size) + << "File size " << size << ", expected " << expect_size_bytes + << ", file path " << file_path; + } + + void* data = mmap(/*addr=*/nullptr, size, /*prot=*/PROT_READ, + /*flags=*/MAP_SHARED, fd, /*offset=*/0); + RET_CHECK_NE(data, MAP_FAILED); + RET_CHECK_NE(data, nullptr); + + return std::shared_ptr(static_cast(data), + [](auto* p) {}); + } else { + auto read_buffer = std::make_shared(); + MP_RETURN_IF_ERROR( + file::GetContents(file_path, read_buffer.get(), file::Defaults())); + + if (expect_size_bytes) { + RET_CHECK_EQ(expect_size_bytes, read_buffer->size()) + << "File size " << read_buffer->size() << ", expected " + << expect_size_bytes << ", file path " << file_path; + } + + return std::shared_ptr( + read_buffer, reinterpret_cast(read_buffer->data())); + } +} + +} // namespace xnn_utils +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.cc b/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.cc new file mode 100644 index 000000000..8d185ebd9 --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.cc @@ -0,0 +1,358 @@ +#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "file/base/helpers.h" +#include "file/base/options.h" +#include "mediapipe/framework/port/ret_check.h" +#include 
"mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h" +#include "third_party/XNNPACK/include/xnnpack.h" + +namespace mediapipe { +namespace xnn_utils { + +absl::Status FillXnnRoPEWeights(Tensor& out_seg_pos) { + RET_CHECK_EQ(out_seg_pos.dims.size(), 2); + const size_t max_seq_len = out_seg_pos.dims[0]; + const size_t num_channels = out_seg_pos.dims[1]; + return out_seg_pos.LoadFromVec(FillXnnRoPEWeights(max_seq_len, num_channels)); +} + +std::ostream& operator<<(std::ostream& os, const Tensor& tensor) { + os << "Tensor{dims=[" << tensor.dims << "], datatype=" << tensor.datatype + << ", num_elements=" << tensor.num_elements << "}"; + return os; +} + +std::ostream& operator<<(std::ostream& os, const QCTensor& tensor) { + os << "QCTensor{dims=[" << tensor.dims << "], dim_scale=" << tensor.dim_scale + << " datatype=" << tensor.datatype + << ", num_elements=" << tensor.num_elements << "}"; + return os; +} + +bool Tensor::operator==(const Tensor& other) const { + if (dims.size() != other.dims.size()) { + return false; + } else if (datatype != other.datatype) { + return false; + } else { + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] != other.dims[i]) { + return false; + } + } + } + return 0 == memcmp(Data(), other.Data(), num_elements * ElementSize()); +} + +void Tensor::AllocateBufferIfNeeded() { + if (!flat_data) { + auto real_buffer = std::make_shared(); + real_buffer->reserve(num_elements * ElementSize() + XNN_EXTRA_BYTES); + flat_data = std::shared_ptr(real_buffer, real_buffer->data()); + } +} + +void* Tensor::Data() { + DCHECK(flat_data) + << "If this is weight, you may need to call one of the LoadFrom*()"; + return flat_data.get(); +} + +std::shared_ptr Tensor::Slice(DimsType offset) { + DCHECK(flat_data); + CHECK_EQ(offset.size(), dims.size()) << offset << " vs. " << dims; + // offset: [0, k, 0, 0], dims: [1, K, _, _]. dims before k must be 1. 
+ bool found_non_zero_offset = false; + int index_k = -1; + for (int i = 0; i < dims.size(); ++i) { + if (found_non_zero_offset) { + DCHECK_EQ(offset[i], 0); + } else if (offset[i] != 0) { + found_non_zero_offset = true; + index_k = i; + } + } + DCHECK(found_non_zero_offset) << offset; + + return Slice(index_k, offset[index_k]); +} + +std::shared_ptr Tensor::Slice(size_t index, size_t offset) { + size_t num_elements_offset = 1; + DimsType new_dim = dims; + for (int i = 0; i < dims.size(); ++i) { + if (i < index) { + DCHECK_EQ(dims[i], 1); + } else if (i == index) { + num_elements_offset *= offset; + new_dim[i] = 1; + } else { + num_elements_offset *= dims[i]; + } + } + + auto result = std::make_shared(std::move(new_dim), datatype); + result->flat_data = std::shared_ptr( + flat_data, flat_data.get() + num_elements_offset * ElementSize()); + return result; +} + +Tensor& Tensor::Borrow(std::shared_ptr other, size_t element_offset) { + DCHECK_EQ(datatype, other->datatype); + DCHECK_EQ(dims.size(), other->dims.size()); + flat_data = std::shared_ptr( + other->flat_data, + other->flat_data.get() + element_offset * ElementSize()); + return *this; +} + +std::shared_ptr Tensor::View() { return View(dims); } + +std::shared_ptr Tensor::View(DimsType as_dims, size_t) { + auto result = std::make_shared(as_dims, datatype); + DCHECK_LE(result->num_elements, num_elements); + result->flat_data = flat_data; + return result; +} + +const void* Tensor::Data() const { return const_cast(this)->Data(); } + +absl::Status Tensor::DefineAsExternal(xnn_subgraph& subgraph, uint32_t flags) { + uint32_t id; + RET_CHECK_EQ(xnn_status_success, + xnn_define_tensor_value(&subgraph, datatype, dims.size(), + dims.data(), /*data=*/nullptr, + /*external_id=*/tensor_id, flags, &id)); + if (tensor_id == XNN_INVALID_VALUE_ID) { + RET_CHECK_NE(id, XNN_INVALID_VALUE_ID); + tensor_id = id; + } else { + RET_CHECK_EQ(id, tensor_id); + } + return absl::OkStatus(); +} + +absl::Status 
Tensor::DefineAsInput(xnn_subgraph& subgraph) { + return DefineAsExternal(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT); +} + +absl::Status Tensor::DefineAsOutput(xnn_subgraph& subgraph) { + return DefineAsExternal(subgraph, XNN_VALUE_FLAG_EXTERNAL_OUTPUT); +} + +absl::Status Tensor::DefineAsIntermediateTensor(xnn_subgraph& subgraph) { + RET_CHECK_EQ(tensor_id, XNN_INVALID_VALUE_ID); + return DefineAsExternal(subgraph, 0); +} + +absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) { + RET_CHECK_EQ( + xnn_status_success, + xnn_define_tensor_value(&subgraph, datatype, dims.size(), dims.data(), + Data(), tensor_id, flags, &tensor_id)); + RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID); + return absl::OkStatus(); +} + +absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph) { + RET_CHECK_EQ(tensor_id, XNN_INVALID_VALUE_ID); + return DefineWeight(subgraph, 0); +} + +absl::Status Tensor::DefineRope(xnn_subgraph& subgraph) { + RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID); + return DefineWeight(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT); +} + +absl::Status Tensor::LoadFromBuffer(const void* buffer) { + AllocateBufferIfNeeded(); + memcpy(Data(), buffer, num_elements * ElementSize()); + return absl::OkStatus(); +} + +absl::Status Tensor::LoadFromVec(const std::vector& data, + bool exact_match) { + AllocateBufferIfNeeded(); + if (exact_match) { + RET_CHECK_EQ(num_elements * ElementSize(), data.size() * sizeof(float)); + } + + memcpy(Data(), data.data(), data.size() * sizeof(float)); + + return absl::OkStatus(); +} + +absl::Status Tensor::LoadFromVec(std::vector&& data, bool exact_match) { + if (exact_match) { + RET_CHECK_EQ(num_elements * ElementSize(), data.size() * sizeof(float)); + } + + auto real_buffer = std::make_shared>(std::move(data)); + if (real_buffer->size() < num_elements) { + real_buffer->resize(num_elements); + } + flat_data = std::shared_ptr( + real_buffer, reinterpret_cast(real_buffer->data())); + + return absl::OkStatus(); +} + +absl::Status 
Tensor::DumpToBuffer(void* buffer) { + memcpy(buffer, Data(), num_elements * ElementSize()); + return absl::OkStatus(); +} + +absl::Status Tensor::DumpToVec(std::vector& out_data, bool exact_match) { + if (exact_match) { + RET_CHECK_EQ(num_elements * ElementSize(), out_data.size() * sizeof(float)); + } else { + out_data.resize(num_elements); + } + memcpy(out_data.data(), Data(), num_elements * ElementSize()); + return absl::OkStatus(); +} + +absl::Status Tensor::DumpToFile(absl::string_view file_path) { + return file::SetContents( + file_path, + absl::string_view(flat_data.get(), num_elements * ElementSize()), + file::Defaults()); +} + +absl::Status Tensor::LoadFromFile(absl::string_view file_path, bool use_mmap, + bool exact_match) { + const size_t expected_size_in_bytes = + exact_match ? num_elements * ElementSize() : 0; + + ASSIGN_OR_RETURN(flat_data, LoadBufferFromFile(file_path, use_mmap, + expected_size_in_bytes)); + return absl::OkStatus(); +} + +std::shared_ptr Tensor::Transpose() { + DCHECK_EQ(dims.size(), 2); + DimsType out_dims{dims.rbegin(), dims.rend()}; + auto result = std::make_shared(std::move(out_dims), datatype); + result->AllocateBufferIfNeeded(); + xnn_status s; + const DimsType perm{1, 0}; + if (datatype == xnn_datatype_fp32) { + s = xnn_run_transpose_nd_x32(Data(), result->Data(), dims.size(), + dims.data(), perm.data(), + /*flags=*/0, /*threadpool=*/nullptr); + } else { + LOG(FATAL) << "Need update to support new type"; + } + DCHECK_EQ(s, xnn_status_success); + return (s == xnn_status_success) ? 
result : nullptr; +} + +absl::StatusOr> Tensor::ConvertToF32() { + auto result = std::make_shared(dims, xnn_datatype_fp32); + MP_RETURN_IF_ERROR(result->LoadFromBuffer(Data())); + return result; +} + +absl::Status QCTensor::LoadFromFile(absl::string_view quantized_weight_filename, + absl::string_view scale_filename, + bool use_mmap, bool exact_match) { + size_t scale_element_size = dims[dim_scale]; + + ASSIGN_OR_RETURN(flat_data, + LoadBufferFromFile(quantized_weight_filename, use_mmap, + exact_match ? num_elements : 0)); + ASSIGN_OR_RETURN(scale_data, + LoadBufferFromFile( + scale_filename, use_mmap, + exact_match ? scale_element_size * sizeof(float) : 0)); + return absl::OkStatus(); +} + +absl::Status QCTensor::DumpToFile(absl::string_view file_path) { + MP_RETURN_IF_ERROR(file::SetContents( + file_path, + absl::string_view(flat_data.get(), num_elements * ElementSize()), + file::Defaults())); + return file::SetContents( + absl::StrCat(file_path, kQuantizedScaleSuffix), + absl::string_view(reinterpret_cast(scale_data.get()), + dims[dim_scale] * sizeof(float)), + file::Defaults()); +} + +absl::Status QCTensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) { + RET_CHECK_EQ( + xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + &subgraph, datatype, scale_data.get(), dims.size(), dim_scale, + dims.data(), Data(), XNN_INVALID_VALUE_ID, flags, &tensor_id)) + << *this; + RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID); + return absl::OkStatus(); +} + +void QCTensor::AllocateBufferIfNeeded() { + Tensor::AllocateBufferIfNeeded(); + if (!scale_data) { + auto real_buffer = std::make_shared>(); + real_buffer->reserve(dims[dim_scale]); + scale_data = std::shared_ptr(real_buffer, real_buffer->data()); + } +} + +std::shared_ptr QCTensor::Transpose() { + DCHECK_EQ(dims.size(), 2); + size_t channel_size = dims[dim_scale]; + DimsType out_dims{dims.rbegin(), dims.rend()}; + auto result = std::make_shared(std::move(out_dims), 1 - dim_scale); + 
result->AllocateBufferIfNeeded(); + memcpy(result->scale_data.get(), scale_data.get(), + channel_size * sizeof(float)); + xnn_status s; + const DimsType perm{1, 0}; + if (datatype == xnn_datatype_qcint8) { + s = xnn_run_transpose_nd_x8(Data(), result->Data(), dims.size(), + dims.data(), perm.data(), + /*flags=*/0, /*threadpool=*/nullptr); + } else { + LOG(FATAL) << "Need update to support new type"; + } + DCHECK_EQ(s, xnn_status_success); + return (s == xnn_status_success) ? result : nullptr; +} + +absl::StatusOr> QCTensor::ConvertToF32() { + auto result = std::make_shared(dims, xnn_datatype_fp32); + // TODO: proper implement. + LOG(WARNING) << "This is fake impl"; + MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false)); + return result; +} + +std::shared_ptr QCTensor::View(DimsType as_dims, + size_t dim_scale_if_any) { + auto result = std::make_shared(as_dims, dim_scale_if_any); + DCHECK_LE(result->num_elements, num_elements); + result->flat_data = flat_data; + result->scale_data = scale_data; + return result; +} + +} // namespace xnn_utils +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h b/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h new file mode 100644 index 000000000..10324ff4f --- /dev/null +++ b/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h @@ -0,0 +1,202 @@ +#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "file/base/helpers.h" +#include "file/base/options.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h" 
+#include "third_party/XNNPACK/include/xnnpack.h" +#include "util/gtl/stl_logging.h" + +namespace mediapipe { +namespace xnn_utils { + +static constexpr absl::string_view kQuantizedScaleSuffix{"_quantized_scale"}; +static constexpr absl::string_view kSparsityParamsSuffix{"_sparsity_params"}; + +struct Tensor { + using DimsType = std::vector; + + explicit Tensor(DimsType in_dims, xnn_datatype datatype_ = xnn_datatype_fp32) + : dims(std::move(in_dims)), + num_elements(dims.empty() ? 0 + : std::accumulate(std::begin(dims), + std::end(dims), size_t(1), + std::multiplies())), + datatype(datatype_) {} + Tensor(Tensor&& other) = default; + + Tensor& operator=(const Tensor& other) = delete; + Tensor& operator=(Tensor&& other) = default; + + virtual ~Tensor() = default; + + bool operator==(const Tensor& other) const; + + void SetMetadata(absl::string_view key, int value) { metadata[key] = value; } + + std::optional GetMetadata(absl::string_view key) const { + if (metadata.contains(key)) { + return metadata.at(key); + } + return std::nullopt; + } + + // Read weights from file. 
+ template + static absl::StatusOr> FromFile( + absl::string_view file_path, DimsType dims, bool use_mmap = true) { + auto result = std::make_shared(std::move(dims), xnn_datatype_); + + MP_RETURN_IF_ERROR( + result->LoadFromFile(file_path, use_mmap, /*exact_match=*/true)); + + return result; + } + + virtual absl::Status DefineAsExternal(xnn_subgraph& subgraph, uint32_t flags); + absl::Status DefineAsInput(xnn_subgraph& subgraph); + absl::Status DefineAsOutput(xnn_subgraph& subgraph); + absl::Status DefineAsIntermediateTensor(xnn_subgraph& subgraph); + virtual absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags); + absl::Status DefineWeight(xnn_subgraph& subgraph); + absl::Status DefineRope(xnn_subgraph& subgraph); + + absl::Status LoadFromBuffer(const void* buffer); + absl::Status LoadFromVec(const std::vector& data, + bool exact_match = true); + absl::Status LoadFromVec(std::vector&& data, bool exact_match = true); + absl::Status LoadFromFile(absl::string_view file_path) { + return LoadFromFile(file_path, true, true); + } + virtual absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap, + bool exact_match); + + absl::Status DumpToBuffer(void* buffer); + absl::Status DumpToVec(std::vector& out_data, bool exact_match = true); + virtual absl::Status DumpToFile(absl::string_view file_path); + + // If ith offset is 0, view's ith dim equals to original ith dim, otherwise 1. + std::shared_ptr Slice(DimsType offset); + // Slice along the `index`th dimension, offset at this dimension. + std::shared_ptr Slice(size_t index, size_t offset); + + // Point the underline data to the borrowed tensor's data. 
+ Tensor& Borrow(std::shared_ptr, size_t element_offset = 0); + std::shared_ptr View(); + virtual std::shared_ptr View(DimsType as_dims, + size_t dim_scale_if_any = 0); + + Tensor& MarkOutput() { + AllocateBufferIfNeeded(); + is_output_tensor = true; + return *this; + } + + virtual void* Data(); + const void* Data() const; + + template + T* DataAs() { + DCHECK_EQ(ElementSize(), sizeof(T)); + return static_cast(Data()); + } + template + const T* DataAs() const { + return static_cast(Data()); + } + + virtual std::shared_ptr Transpose(); + + virtual absl::StatusOr> ConvertToF32(); + + DimsType dims; + size_t num_elements = 0; + xnn_datatype datatype = xnn_datatype_invalid; + uint32_t tensor_id = XNN_INVALID_VALUE_ID; + + // shared_ptr to make TensorMetadata copyable. + std::shared_ptr flat_data; + + protected: + friend class XnnGraphBuilder; + friend class XnnGraph; + + // Actually allocate buffer unless necessary. + virtual void AllocateBufferIfNeeded(); + + virtual size_t ElementSize() const { return 4; } + + bool is_output_tensor = false; + + absl::flat_hash_map metadata; +}; + +std::ostream& operator<<(std::ostream& os, const Tensor& tensor); + +// Channelwise Quantized. +struct QCTensor : public Tensor { + explicit QCTensor(DimsType in_dims, size_t dim_scale_if_any) + : Tensor(std::move(in_dims)), dim_scale(dim_scale_if_any) { + datatype = xnn_datatype_qcint8; + CHECK_LT(dim_scale, 4); + } + + void AllocateBufferIfNeeded() override; + size_t ElementSize() const override { return 1; } + + virtual absl::Status LoadFromFile(absl::string_view quantized_weight_filename, + absl::string_view scale_filename, + bool use_mmap, bool exact_match); + // Append kQuantizedScaleSuffix to use as scale filename. 
+ absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap, + bool exact_match) override { + return LoadFromFile(file_path, + absl::StrCat(file_path, kQuantizedScaleSuffix), + use_mmap, exact_match); + } + + absl::Status DumpToFile(absl::string_view file_path) override; + + absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags) override; + + std::shared_ptr Transpose() override; + + absl::StatusOr> ConvertToF32() override; + + std::shared_ptr View(DimsType as_dims, + size_t dim_scale_if_any) override; + + std::shared_ptr scale_data; + // Index of the dimension to scale. + size_t dim_scale; +}; + +std::ostream& operator<<(std::ostream& os, const QCTensor& tensor); + +absl::Status FillXnnRoPEWeights(Tensor& out_seg_pos); + +} // namespace xnn_utils +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_