diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/BUILD b/mediapipe/tasks/cc/text/utils/xnn_utils/BUILD deleted file mode 100644 index 4b58cb8f6..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/BUILD +++ /dev/null @@ -1 +0,0 @@ -# Utilities needed to interacte with XNNPACK. diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.cc b/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.cc deleted file mode 100644 index 225b5985d..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.cc +++ /dev/null @@ -1,887 +0,0 @@ -#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "absl/types/source_location.h" -#include "file/base/helpers.h" -#include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" -#include "third_party/XNNPACK/include/xnnpack.h" -#include "util/gtl/stl_logging.h" - -namespace mediapipe { -namespace xnn_utils { -namespace { - -// XNNPACK supports broadcasting, this function inferences the output shape -// based on input tensor shapes. -std::vector OutDimsForElementwiseOp(const Tensor& lhs, - const Tensor& rhs) { - DCHECK(!lhs.dims.empty()); - DCHECK(!rhs.dims.empty()); - std::vector lhs_dims_rev(lhs.dims.rbegin(), lhs.dims.rend()); - std::vector rhs_dims_rev(rhs.dims.rbegin(), rhs.dims.rend()); - DCHECK([&]() -> bool { - for (size_t i = 0; i < std::min(lhs_dims_rev.size(), rhs_dims_rev.size()); - ++i) { - if ((lhs_dims_rev[i] != rhs_dims_rev[i]) && (lhs_dims_rev[i] != 1) && - (rhs_dims_rev[i] != 1)) { - return false; - } - } - return true; - }()) << "lhs " - << lhs.dims << " rhs " << rhs.dims; - std::vector out_dims( - std::max(lhs_dims_rev.size(), rhs_dims_rev.size())); - for (int i = 0; i < out_dims.size(); ++i) { - if (lhs_dims_rev.size() <= i) { - out_dims[i] = rhs_dims_rev[i]; - } else if (rhs_dims_rev.size() <= i) { - out_dims[i] = lhs_dims_rev[i]; - } else { - out_dims[i] = lhs_dims_rev[i] == 1 ? rhs_dims_rev[i] : lhs_dims_rev[i]; - } - } - return std::vector(out_dims.rbegin(), out_dims.rend()); -} - -// If out_id is invalid, we need to allocate tensor for intermediate result. -// Otherwise, set out_id in out_metadata. -absl::Status MaybeAllocateIntermediateTensor(xnn_subgraph_t subgraph, - uint32_t out_id, - Tensor& out_metadata) { - RET_CHECK_GT(out_metadata.dims.size(), 0); - if (out_id == XNN_INVALID_VALUE_ID) { - // The output is intermediate, thus allocate tensor. - MP_RETURN_IF_ERROR(out_metadata.DefineAsIntermediateTensor(*subgraph)); - } else { - out_metadata.tensor_id = out_id; - } - - return absl::OkStatus(); -} - -absl::Status MaybeAllocateIntermediateTensor(xnn_subgraph_t subgraph, - Tensor& out_metadata) { - return MaybeAllocateIntermediateTensor(subgraph, out_metadata.tensor_id, - out_metadata); -} - -absl::Status AllocateIntermediateTensor(xnn_subgraph_t subgraph, - Tensor& out_metadata) { - return MaybeAllocateIntermediateTensor(subgraph, XNN_INVALID_VALUE_ID, - out_metadata); -} - -// 1.0/jax.nn.softplus(0.0) = 1.442695041 -// scale = softplus(w) * 1.442695041 / np.sqrt(query.shape[-1]) -void SoftPlus(size_t cnt, const std::vector& query_dims, float* weight, - float* scale) { - constexpr double r_softplus_0 = 1.442695041; - // softplus(x) = np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0) - // scale = softplus(per_dim_scale) / (sqrt(input.dims[-1]) * softplus(0)) - const double r_softplus_0_over_sqrt_d = - r_softplus_0 / std::sqrt(query_dims.back()); - for (int i = 0; i < cnt; ++i) { - scale[i] = log1p(exp(-abs(weight[i]))) + fmax(weight[i], 0.0f); - scale[i] *= r_softplus_0_over_sqrt_d; - } -} - -} // namespace - -absl::StatusOr> XnnGraphBuilder::Build( - std::unique_ptr runtime_configs) { - if (!runtime_configs) { - runtime_configs = std::make_unique(); - runtime_configs->xnn_num_threads = 1; - runtime_configs->xnn_profile = false; - } - VLOG(2) << "XnnGraphBuilder::Build() building..."; - auto build_begin = absl::Now(); - RET_CHECK_EQ(xnn_status_success, xnn_initialize(nullptr)); - - absl::flat_hash_set> output_tensors; - { - uint32_t cnt = input_tensors_.size(); - for (auto& t : interm_tensors_) { - if (t->is_output_tensor) { - RET_CHECK_EQ(t->tensor_id, XNN_INVALID_VALUE_ID); - t->tensor_id = cnt++; - output_tensors.insert(t); - } - } - for (auto& t : output_tensors) { - interm_tensors_.erase(t); - } - for (auto& t : rope_weigths_) { - interm_tensors_.erase(t); - t->tensor_id = cnt++; - } - } - - xnn_subgraph_t subgraph_ptr = nullptr; - RET_CHECK_EQ(xnn_status_success, - xnn_create_subgraph( - /*external_value_ids=*/input_tensors_.size() + - output_tensors.size() + rope_weigths_.size(), - /*flags=*/0, &subgraph_ptr)); - RET_CHECK_NE(subgraph_ptr, nullptr); - - XnnSubgraphPtr subgraph{subgraph_ptr, xnn_delete_subgraph}; - - for (auto& input : input_tensors_) { - MP_RETURN_IF_ERROR(input->DefineAsInput(*subgraph)); - } - for (auto& output : output_tensors) { - MP_RETURN_IF_ERROR(output->DefineAsOutput(*subgraph)); - } - { - for (auto& t : rope_weigths_) { - MP_RETURN_IF_ERROR(t->DefineRope(*subgraph)); - } - } - - for (auto& [loc, step] : build_steps_) { - if (auto s = step(subgraph.get()); !s.ok()) { - s.AddSourceLocation(loc); - return s; - } - } - - XnnGraph result(std::move(subgraph), std::move(runtime_configs)); - result.input_tensors_ = std::move(input_tensors_); - result.output_tensors_ = std::move(output_tensors); - result.interm_tensors_ = std::move(interm_tensors_); - - VLOG(2) << "XnnGraphBuilder::Build() creating runtime..."; - auto create_begin = absl::Now(); - MP_RETURN_IF_ERROR(result.CreateRuntime()); - VLOG(2) << "XnnGraphBuilder::Build() setting up runtime..."; - auto setup_begin = absl::Now(); - MP_RETURN_IF_ERROR(result.SetupRuntime()); - - auto end = absl::Now(); - VLOG(2) << "XnnGraphBuilder::Build() done build, Total " << end - build_begin - << ", create runtime " << setup_begin - create_begin - << ", setup runtime " << end - setup_begin; - return std::make_unique(std::move(result)); -} - -absl::StatusOr> XnnGraphBuilder::NewInput( - Tensor::DimsType dims, absl::SourceLocation loc) { - auto t = std::make_shared(std::move(dims), data_type_); - t->AllocateBufferIfNeeded(); - t->tensor_id = input_tensors_.size(); - input_tensors_.insert(t); - return t; -} - -absl::StatusOr> XnnGraphBuilder::NewWeight( - absl::string_view file_path, Tensor::DimsType dims, - absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto t, NewWeight(std::move(dims))); - MP_RETURN_IF_ERROR(t->LoadFromFile(file_path)); - return t; -} - -absl::StatusOr> XnnGraphBuilder::NewWeight( - Tensor::DimsType dims, absl::SourceLocation loc) { - auto t = std::make_shared(std::move(dims), data_type_); - NewWeight(t, loc); - return t; -} - -void XnnGraphBuilder::NewWeight(std::shared_ptr t, - absl::SourceLocation loc) { - build_steps_.push_back( - {loc, [this, t](xnn_subgraph_t subgraph) -> absl::Status { - if (interm_tensors_.contains(t)) { - MP_RETURN_IF_ERROR(t->DefineWeight(*subgraph)); - } - return absl::OkStatus(); - }}); - - interm_tensors_.insert(t); -} - -absl::StatusOr> XnnGraphBuilder::IntermediateTensor( - Tensor::DimsType dims, absl::SourceLocation loc) { - auto t = std::make_shared(std::move(dims), data_type_); - - build_steps_.push_back( - {loc, [this, t](xnn_subgraph_t subgraph) -> absl::Status { - // Could be moved to output tensors, thus need check. - if (interm_tensors_.contains(t)) { - return AllocateIntermediateTensor(subgraph, *t); - } - return absl::OkStatus(); - }}); - - interm_tensors_.insert(t); - return t; -} - -absl::StatusOr> XnnGraphBuilder::Reshape( - std::shared_ptr input, Tensor::DimsType new_dims, - absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(new_dims))); - RET_CHECK_EQ(input->num_elements, output->num_elements) - << "otherwise reshape does not make sense."; - - build_steps_.push_back( - {loc, [this, input, output](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor( - subgraph, output->tensor_id, *output)); - - RET_CHECK_EQ(xnn_status_success, - xnn_define_static_reshape( - subgraph, output->dims.size(), output->dims.data(), - input->tensor_id, output->tensor_id, /*flags=*/0)); - return absl::OkStatus(); - }}); - return output; -} - -absl::StatusOr> XnnGraphBuilder::FullConn( - std::shared_ptr input, std::shared_ptr weight, - std::shared_ptr bias, FullConnParams params, - absl::SourceLocation loc) { - const auto& input_dim = input->dims; - const auto& weight_dim = weight->dims; - DCHECK_GT(input_dim.size(), 1); - DCHECK_GE(weight_dim.size(), 2); - if (weight_dim.size() == 3) { - RET_CHECK_EQ(weight_dim[0], 1); - } else if (weight_dim.size() == 4) { - RET_CHECK_EQ(weight_dim[0], 1); - RET_CHECK_EQ(weight_dim[1], 1); - } - if (bias) { - RET_CHECK_LE(bias->dims.size(), 1); - } - - Tensor::DimsType out_dims = input_dim; - // Not considering reshape 2D - if (params.transpose) { - RET_CHECK_EQ(weight_dim.size(), 2) << "otherwise change following line"; - RET_CHECK_EQ(input_dim.back(), *(weight_dim.end() - 2)); - out_dims.back() = weight_dim.back(); - } else { - RET_CHECK_EQ(input_dim.back(), weight_dim.back()); - out_dims.pop_back(); - for (size_t i = 0; i < weight_dim.size() - 1; ++i) { - // NHD . BTD -> NHBT - out_dims.push_back(weight_dim[i]); - } - } - ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(out_dims))); - - build_steps_.push_back( - {loc, - [this, input, weight, bias, params, - output](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor( - subgraph, output->tensor_id, *output)); - - RET_CHECK_EQ( - xnn_status_success, - xnn_define_fully_connected( - subgraph, params.out_min, params.out_max, input->tensor_id, - weight->tensor_id, - bias ? bias->tensor_id : XNN_INVALID_VALUE_ID, - output->tensor_id, - /*flags=*/params.transpose ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0)); - - return absl::OkStatus(); - }}); - return output; -} - -absl::StatusOr> XnnGraphBuilder::Permute( - std::shared_ptr input, Tensor::DimsType permute, - absl::SourceLocation loc) { - RET_CHECK_EQ(input->dims.size(), permute.size()); - const auto& old_dims = input->dims; - std::vector new_dims; - for (size_t i = 0; i < permute.size(); ++i) { - new_dims.push_back(old_dims[permute[i]]); - } - ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(new_dims))); - - build_steps_.push_back( - {loc, - [this, permute, input, output](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); - - RET_CHECK_EQ(xnn_status_success, - xnn_define_static_transpose( - subgraph, permute.size(), permute.data(), - input->tensor_id, output->tensor_id, /*flags=*/0)); - return absl::OkStatus(); - }}); - return output; -} - -absl::StatusOr> XnnGraphBuilder::Square( - std::shared_ptr input, absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims)); - - build_steps_.push_back( - {loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor( - subgraph, output->tensor_id, *output)); - RET_CHECK_EQ( - xnn_status_success, - xnn_define_square(subgraph, input->tensor_id, output->tensor_id, - /*flags=*/0)); - return absl::Status(); - }}); - - return output; -} - -absl::StatusOr> XnnGraphBuilder::Softmax( - std::shared_ptr input, absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims)); - - build_steps_.push_back( - {loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor( - subgraph, output->tensor_id, *output)); - RET_CHECK_EQ( - xnn_status_success, - xnn_define_softmax(subgraph, input->tensor_id, output->tensor_id, - /*flags=*/0)); - return absl::Status(); - }}); - - return output; -} - -absl::StatusOr> XnnGraphBuilder::SquareRoot( - std::shared_ptr input, absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims)); - - build_steps_.push_back( - {loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor( - subgraph, output->tensor_id, *output)); - RET_CHECK_EQ(xnn_status_success, - xnn_define_square_root(subgraph, input->tensor_id, - output->tensor_id, - /*flags=*/0)); - return absl::Status(); - }}); - - return output; -} - -absl::StatusOr> XnnGraphBuilder::AvgLastDim( - std::shared_ptr input, absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto before_reshape, - IntermediateTensor(Tensor::DimsType{input->dims.begin(), - input->dims.end() - 1})); - build_steps_.push_back( - {loc, - [this, input, before_reshape](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor( - subgraph, before_reshape->tensor_id, *before_reshape)); - size_t reduction_axis = input->dims.size() - 1; - RET_CHECK_EQ( - xnn_status_success, - xnn_define_static_mean(subgraph, 1, &reduction_axis, - input->tensor_id, before_reshape->tensor_id, - /*flags=*/0)); - return absl::OkStatus(); - }}); - - Tensor::DimsType new_dims = input->dims; - new_dims.back() = 1; - return Reshape(before_reshape, std::move(new_dims)); -} - -absl::StatusOr> XnnGraphBuilder::Rms( - std::shared_ptr input, absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto sqr_out, Square(input, loc)); - - ASSIGN_OR_RETURN(auto mean_out, AvgLastDim(sqr_out, loc)); - - return SquareRoot(mean_out, loc); -} - -absl::StatusOr> XnnGraphBuilder::RmsNorm( - std::shared_ptr input, std::shared_ptr scale, - absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto rms_out, Rms(input)); - - ASSIGN_OR_RETURN(auto clamped_rms, Clamp(rms_out, {.out_min = 1e-6})); - - // div_out = input / rms - ASSIGN_OR_RETURN(auto div_out, ElementDiv(input, clamped_rms)); - - // div_out * (1 + scale) = div_out + div_out * scale - ASSIGN_OR_RETURN(auto normed_div_out, ElementMul(div_out, scale)); - - return ElementAdd(div_out, normed_div_out); -} - -absl::StatusOr> XnnGraphBuilder::ElementAdd( - std::shared_ptr lhs, float rhs, ClampParams params, - absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1})); - MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector({rhs}))); - - return ElementAdd(lhs, rhs_tensor, params, loc); -} - -absl::StatusOr> XnnGraphBuilder::ElementAdd( - std::shared_ptr lhs, std::shared_ptr rhs, - ClampParams params, absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto output, - IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs))); - - build_steps_.push_back( - {loc, - [this, lhs, rhs, output, - params](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); - RET_CHECK_EQ(xnn_status_success, - xnn_define_add2(subgraph, params.out_min, params.out_max, - lhs->tensor_id, rhs->tensor_id, - output->tensor_id, /*flags=*/0)); - return absl::OkStatus(); - }}); - - return output; -} - -absl::StatusOr> XnnGraphBuilder::ElementMul( - std::shared_ptr lhs, float rhs, ClampParams params, - absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1})); - MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector({rhs}))); - - return ElementMul(lhs, rhs_tensor, params, loc); -} - -absl::StatusOr> XnnGraphBuilder::ElementMul( - std::shared_ptr lhs, std::shared_ptr rhs, - ClampParams params, absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto output, - IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs))); - - build_steps_.push_back( - {loc, - [this, lhs, rhs, output, - params](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); - RET_CHECK_EQ( - xnn_status_success, - xnn_define_multiply2(subgraph, params.out_min, params.out_max, - lhs->tensor_id, rhs->tensor_id, - output->tensor_id, /*flags=*/0)); - return absl::OkStatus(); - }}); - - return output; -} - -absl::StatusOr> XnnGraphBuilder::ElementDiv( - std::shared_ptr lhs, float rhs, ClampParams params, - absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1})); - MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector({rhs}))); - - return ElementDiv(lhs, rhs_tensor, params, loc); -} - -absl::StatusOr> XnnGraphBuilder::ElementDiv( - std::shared_ptr lhs, std::shared_ptr rhs, - ClampParams params, absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto output, - IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs))); - - build_steps_.push_back( - {loc, - [this, lhs, rhs, output, - params](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); - RET_CHECK_EQ( - xnn_status_success, - xnn_define_divide(subgraph, params.out_min, params.out_max, - lhs->tensor_id, rhs->tensor_id, - output->tensor_id, /*flags=*/0)); - return absl::OkStatus(); - }}); - - return output; -} - -// TODO: write an op? -absl::StatusOr> XnnGraphBuilder::PerDimScale( - std::shared_ptr input, std::shared_ptr per_dim_scale, - absl::SourceLocation loc) { - // input: B T N H - // 1/softplus(0) = 1.442695041 - // scale = softplus(w) * 1.442695041 / np.sqrt(query.shape[-1]) - // query = query * scale - const auto& input_dim = input->dims; - DCHECK_GE(input_dim.size(), 1); - const size_t H = input_dim.back(); - - if (!per_dim_scale_cache_.contains(H) || - !per_dim_scale_cache_[H].contains(per_dim_scale.get())) { - ASSIGN_OR_RETURN(auto cached_pds, NewWeight(per_dim_scale->dims)); - - auto* pds_in = static_cast(per_dim_scale->Data()); - std::vector pds_scaled(per_dim_scale->num_elements); - SoftPlus(per_dim_scale->num_elements, input_dim, pds_in, pds_scaled.data()); - MP_RETURN_IF_ERROR(cached_pds->LoadFromVec(std::move(pds_scaled))); - per_dim_scale_cache_[H][per_dim_scale.get()] = cached_pds; - } - - return ElementMul(input, per_dim_scale_cache_[H][per_dim_scale.get()]); -} - -absl::StatusOr> XnnGraphBuilder::Rope( - std::shared_ptr input, std::shared_ptr segment_pos, - absl::SourceLocation loc) { - // TODO: seg_pos should not be weight. - rope_weigths_.insert(segment_pos); - - const auto& input_dim = input->dims; - const auto& segment_pos_dim = segment_pos->dims; - // B T N H - RET_CHECK_EQ(input_dim.size(), 4) << "xnn requirement"; - // S H - RET_CHECK_EQ(segment_pos_dim.size(), 2) << "xnn requirement"; - - ASSIGN_OR_RETURN(auto output, IntermediateTensor(input_dim)); - - const auto input_seq_size = input_dim[1]; - RET_CHECK_LE(input_seq_size, segment_pos_dim[0]); - const auto head_dim_H = input_dim[3]; - RET_CHECK_EQ(head_dim_H, segment_pos_dim[1]); - - build_steps_.push_back( - {loc, - [this, input, output, segment_pos, - input_seq_size](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); - RET_CHECK_EQ( - xnn_status_success, - xnn_define_rope(subgraph, input_seq_size, input->tensor_id, - segment_pos->tensor_id, output->tensor_id, - /*flags=*/0)); - return absl::OkStatus(); - }}); - - return output; -} - -absl::StatusOr> XnnGraphBuilder::BatchMatMul( - std::shared_ptr input, std::shared_ptr weight, - FullConnParams params, absl::SourceLocation loc) { - const auto& lhs_dim = input->dims; - const auto& rhs_dim = weight->dims; - - // [B, N, T, H] . [B, N, S, H], N == 12, B == 1 - DCHECK_EQ(lhs_dim.size(), 4); - DCHECK_EQ(rhs_dim.size(), 4); - DCHECK_EQ(lhs_dim.back(), rhs_dim.back()); - DCHECK_EQ(lhs_dim.back(), rhs_dim.back()); - constexpr size_t num_slices = 12; - DCHECK_EQ(lhs_dim[1], num_slices); - DCHECK_EQ(rhs_dim[1], num_slices); - const size_t S = rhs_dim[2]; - const size_t T = lhs_dim[2]; - const size_t batch_size = lhs_dim[0] * lhs_dim[1]; - DCHECK_EQ(batch_size, rhs_dim[0] * rhs_dim[1]); - DCHECK_EQ(batch_size, 12); - - ASSIGN_OR_RETURN(auto output, IntermediateTensor({1, 12, T, S})); - - build_steps_.push_back( - {loc, [input, output, weight](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); - - RET_CHECK_EQ(xnn_status_success, - xnn_define_batch_matrix_multiply( - subgraph, input->tensor_id, weight->tensor_id, - output->tensor_id, /*flags=*/0)); - - return absl::OkStatus(); - }}); - - return output; -} - -absl::StatusOr> XnnGraphBuilder::Tanh( - std::shared_ptr input, absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims)); - - build_steps_.push_back( - {loc, [this, input, output](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); - - RET_CHECK_EQ(xnn_status_success, - xnn_define_tanh(subgraph, input->tensor_id, - output->tensor_id, /*flags=*/0)); - return absl::OkStatus(); - }}); - - return output; -} - -absl::StatusOr> XnnGraphBuilder::CapTanh( - std::shared_ptr input, float cap, absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto div, ElementDiv(input, cap)); - ASSIGN_OR_RETURN(auto tanh, Tanh(div)); - return ElementMul(tanh, cap); -} - -absl::StatusOr> XnnGraphBuilder::DotAttention( - std::shared_ptr query_proj, std::shared_ptr key_proj, - std::shared_ptr value_proj, std::shared_ptr atten_mask, - std::shared_ptr per_dim_scale, absl::SourceLocation loc) { - // BTNH - ASSIGN_OR_RETURN(auto query_after_scale, - PerDimScale(query_proj, per_dim_scale)); - - // Dot similarity - // BTNH -> BNTH - ASSIGN_OR_RETURN(auto query_permuted, - Permute(query_after_scale, {0, 2, 1, 3})); - // BSNH -> BNSH - ASSIGN_OR_RETURN(auto key_permuted, Permute(key_proj, {0, 2, 1, 3})); - // einsum(BNTH.BNSH -> BNTS) - ASSIGN_OR_RETURN(auto logits, BatchMatMul(query_permuted, key_permuted)); - - // Cap, mask - ASSIGN_OR_RETURN(auto cap_logits, CapTanh(logits, 50)); - ASSIGN_OR_RETURN(auto padded_logits, ElementAdd(atten_mask, cap_logits)); - ASSIGN_OR_RETURN(auto probs, Softmax(padded_logits)); - ASSIGN_OR_RETURN(auto value_permuted, Permute(value_proj, {0, 2, 3, 1})); - - // Outcome - // BNTS.BNHS -> BNTH - ASSIGN_OR_RETURN(auto outcome_before_permute, - BatchMatMul(probs, value_permuted)); - // [B, N, T, H] -> BTNH - return Permute(outcome_before_permute, {0, 2, 1, 3}); -} - -absl::StatusOr> XnnGraphBuilder::SelfAttentionProj( - std::shared_ptr input, std::shared_ptr weight, - absl::SourceLocation loc) { - const auto& input_dim = input->dims; - const auto& weight_dim = weight->dims; - size_t N = 0, H = 0; - RET_CHECK_EQ(input_dim.size(), 3) << "BTD"; - - std::optional reshaped_N = - weight->GetMetadata(kKeySelfAttentionReshapedWeight); - RET_CHECK(reshaped_N && *reshaped_N) - << "We rely on " << kKeySelfAttentionReshapedWeight << " to get N"; - RET_CHECK_EQ(weight_dim.size(), 2) << "NH,D"; - N = *reshaped_N; - H = weight_dim[0] / N; - - // out: B,T,NH - ASSIGN_OR_RETURN(auto proj, MatMul(input, weight)); - - // B,T,NH -> B,T,N,H - return Reshape(proj, {input_dim[0], input_dim[1], N, H}); -} - -absl::Status XnnGraph::CreateRuntime() { - RET_CHECK_EQ(runtime_.get(), nullptr); - xnn_runtime_t runtime_ptr = nullptr; - uint32_t flags = 0; - if (runtime_configs_->xnn_profile) { - flags |= XNN_FLAG_BASIC_PROFILING; - - if (!runtime_configs_->xnn_profile_csv.empty()) { - MP_RETURN_IF_ERROR(file::SetContents(runtime_configs_->xnn_profile_csv, - "node_id; time(us); op_name\n", - file::Defaults())); - } - } - pthreadpool_t threadpool = - pthreadpool_create(runtime_configs_->xnn_num_threads); - threadpool_ = XnnThreadpoolPtr{threadpool, pthreadpool_destroy}; - - RET_CHECK_EQ(xnn_status_success, - xnn_create_runtime_v2(owned_subgraph_.get(), threadpool, flags, - &runtime_ptr)); - RET_CHECK_NE(runtime_ptr, nullptr); - runtime_ = XnnRuntimePtr{runtime_ptr, xnn_delete_runtime}; - - return absl::OkStatus(); -} - -absl::Status XnnGraph::SetupRuntime() { - { - VLOG(3) << "input size " << input_tensors_.size(); - VLOG(3) << "output size " << output_tensors_.size(); - VLOG(3) << "rope size " << rope_weigths_.size(); - externals_.clear(); - // Init external - for (const auto& input : input_tensors_) { - VLOG(3) << "input id " << input->tensor_id; - externals_.push_back(xnn_external_value{input->tensor_id, input->Data()}); - } - for (const auto& output : output_tensors_) { - VLOG(3) << "output id " << output->tensor_id; - externals_.push_back( - xnn_external_value{output->tensor_id, output->Data()}); - } - for (const auto& t : rope_weigths_) { - VLOG(3) << "rope id " << t->tensor_id; - } - } - RET_CHECK_EQ( - xnn_status_success, - xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data())); - return absl::OkStatus(); -} - -absl::Status XnnGraph::Run() { - RET_CHECK(runtime_); - - RET_CHECK_EQ(xnn_status_success, xnn_invoke_runtime(runtime_.get())); - - if (runtime_configs_->xnn_profile) { - size_t required_size = 0; - - // xnn_get_runtime_profiling_info is called twice. The first time it sets - // required_size to the required size of the buffer to store the result and - // returns xnn_status_out_of_memory. The second time it writes the result to - // the buffer provided that the buffer is large enough and returns - // xnn_status_success. - xnn_status status = xnn_get_runtime_profiling_info( - runtime_.get(), xnn_profile_info_operator_name, /*param_value_size*/ 0, - /*param_value*/ nullptr, &required_size); - std::vector operator_names; - if (status == xnn_status_out_of_memory) { - operator_names.resize(required_size); - status = xnn_get_runtime_profiling_info( - runtime_.get(), xnn_profile_info_operator_name, operator_names.size(), - operator_names.data(), &required_size); - } - RET_CHECK_EQ(status, xnn_status_success); - size_t num_operators; - status = xnn_get_runtime_profiling_info( - runtime_.get(), xnn_profile_info_num_operators, sizeof(num_operators), - &num_operators, &required_size); - RET_CHECK_EQ(status, xnn_status_success); - status = xnn_get_runtime_profiling_info( - runtime_.get(), xnn_profile_info_operator_timing, - /*param_value_size*/ 0, - /*param_value*/ nullptr, &required_size); - std::vector operator_timings; - if (status == xnn_status_out_of_memory) { - operator_timings.resize(required_size / sizeof(uint64_t)); - status = xnn_get_runtime_profiling_info( - runtime_.get(), xnn_profile_info_operator_timing, - operator_timings.size() * sizeof(uint64_t), operator_timings.data(), - &required_size); - } - RET_CHECK_EQ(status, xnn_status_success); - const char* operator_name = nullptr; - size_t name_len = 0; - std::stringstream ss; - for (size_t node_index = 0; node_index < num_operators; ++node_index) { - operator_name = &operator_names[name_len]; - name_len += strlen(operator_name) + 1; - VLOG(2) << "XnnGraphBuilder::Profile() node_index: " << node_index - << ", time: " << operator_timings[node_index] << " us, " - << operator_name << "\n"; - if (!runtime_configs_->xnn_profile_csv.empty()) { - // Use ';' instead of ',' because operator_name contains comma. - ss << node_index << "; " << operator_timings[node_index] << "; " - << operator_name << "\n"; - } - } - if (!runtime_configs_->xnn_profile_csv.empty()) { - MP_RETURN_IF_ERROR(file::AppendStringToFile( - runtime_configs_->xnn_profile_csv, ss.str(), file::Defaults())); - } - } - - return absl::OkStatus(); -} - -absl::StatusOr> XnnGraphBuilder::Clamp( - std::shared_ptr input, ClampParams params, - absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims)); - - build_steps_.push_back( - {loc, - [this, input, output, params](xnn_subgraph_t subgraph) -> absl::Status { - MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output)); - - RET_CHECK_EQ(xnn_status_success, - xnn_define_clamp(subgraph, params.out_min, params.out_max, - input->tensor_id, output->tensor_id, - /*flags=*/0)); - return absl::OkStatus(); - }}); - - return output; -} - -absl::StatusOr> XnnGraphBuilder::Gelu( - std::shared_ptr input, absl::SourceLocation loc) { - // x^2 - ASSIGN_OR_RETURN(auto sqr_out, Square(input)); - - // 0.044715 * x^2 - ASSIGN_OR_RETURN(auto sqr_4471, ElementMul(sqr_out, 0.044715)); - - // 1 + 0.044715 * x^2 - ASSIGN_OR_RETURN(auto sqr_4471_1, ElementAdd(sqr_4471, 1.0f)); - - // x + 0.044715 * x^3 - ASSIGN_OR_RETURN(auto x_cube_4471, ElementMul(sqr_4471_1, input)); - - constexpr float sqrt_2_over_pi = 0.7978845608; - ASSIGN_OR_RETURN(auto sqrt_2_over_pi_x_cube_4471, - ElementMul(x_cube_4471, sqrt_2_over_pi)); - - // tanh(x + 0.044715 * x^3) - ASSIGN_OR_RETURN(auto tanh_x_cube_4471, Tanh(sqrt_2_over_pi_x_cube_4471)); - - // 1 + tanh(x + 0.044715 * x^3) - ASSIGN_OR_RETURN(auto tanh_x_cube_4471_1, ElementAdd(tanh_x_cube_4471, 1.0f)); - - // 0.5 * (1 + [tanh(x + 0.044715 * x^3)]) - ASSIGN_OR_RETURN(auto cdf, ElementMul(tanh_x_cube_4471_1, 0.5)); - - return ElementMul(input, cdf); -} - -} // namespace xnn_utils -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h b/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h deleted file mode 100644 index 24b7520ba..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h +++ /dev/null @@ -1,288 +0,0 @@ -#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_ -#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_ - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/source_location.h" -#include "file/base/helpers.h" -#include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" -#include "third_party/XNNPACK/include/xnnpack.h" - -namespace mediapipe { -namespace xnn_utils { - -using XnnSubgraphPtr = - std::unique_ptr; -using XnnRuntimePtr = - std::unique_ptr; -using XnnThreadpoolPtr = - std::unique_ptr; - -struct ClampParams { - float out_min = -std::numeric_limits::infinity(); - float out_max = std::numeric_limits::infinity(); -}; - -struct FullConnParams : public ClampParams { - bool transpose = false; -}; - -struct RuntimeConfigs { - bool xnn_profile; - std::string xnn_profile_csv; - size_t xnn_num_threads; -}; - -class XnnGraph; - -// XnnGraphBuilder is used to construct XnnGraph (through Build()). Once a -// XnnGraph is constructed, it can run for multiple times. -class XnnGraphBuilder { - public: - static constexpr absl::string_view kKeySelfAttentionReshapedWeight{ - "self_attention_reshaped_weight_N"}; - - explicit XnnGraphBuilder(xnn_datatype data_type = xnn_datatype_fp32) - : data_type_(data_type) {} - virtual ~XnnGraphBuilder() = default; - - absl::StatusOr> Build( - std::unique_ptr runtime_configs = nullptr); - - // New input or output tensor. - absl::StatusOr> NewInput( - Tensor::DimsType dims, - absl::SourceLocation loc = absl::SourceLocation::current()); - - // New static weight, populate value before Build() - absl::StatusOr> NewWeight( - Tensor::DimsType dims, - absl::SourceLocation loc = absl::SourceLocation::current()); - absl::StatusOr> NewWeight( - absl::string_view file_path, Tensor::DimsType dims, - absl::SourceLocation loc = absl::SourceLocation::current()); - void NewWeight(std::shared_ptr t, - absl::SourceLocation loc = absl::SourceLocation::current()); - - // Element wise square. - absl::StatusOr> Square( - std::shared_ptr input, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> SquareRoot( - std::shared_ptr input, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> Gelu( - std::shared_ptr input, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> Clamp( - std::shared_ptr input, ClampParams params, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> Tanh( - std::shared_ptr input, - absl::SourceLocation loc = absl::SourceLocation::current()); - - // logits = cap * jnp.tanh(logits / cap) - absl::StatusOr> CapTanh( - std::shared_ptr input, float cap, - absl::SourceLocation loc = absl::SourceLocation::current()); - - // Average over last dimension, keep num of dims same. - absl::StatusOr> AvgLastDim( - std::shared_ptr input, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> Rms( - std::shared_ptr input, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> RmsNorm( - std::shared_ptr input, std::shared_ptr scale, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> Reshape( - std::shared_ptr input, Tensor::DimsType new_dims, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> Permute( - std::shared_ptr input, Tensor::DimsType permute, - absl::SourceLocation loc = absl::SourceLocation::current()); - - // input: [B * I] - // filter: [O * I], [I * O] if transpose - // return: [B * O] - absl::StatusOr> MatMul( - std::shared_ptr input, std::shared_ptr weight, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return MatMul(input, weight, FullConnParams(), loc); - } - - absl::StatusOr> MatMul( - std::shared_ptr input, std::shared_ptr weight, - FullConnParams params, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return FullConn(input, weight, nullptr, params, loc); - } - - absl::StatusOr> BatchMatMul( - std::shared_ptr input, std::shared_ptr weight, - FullConnParams params = FullConnParams(), - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> FullConn( - std::shared_ptr input, std::shared_ptr weight, - std::shared_ptr bias, - absl::SourceLocation loc = absl::SourceLocation::current()) { - return FullConn(input, weight, bias, FullConnParams(), loc); - } - - absl::StatusOr> FullConn( - std::shared_ptr input, std::shared_ptr weight, - std::shared_ptr bias, FullConnParams params, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> Softmax( - std::shared_ptr input, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> SelfAttentionProj( - std::shared_ptr input, std::shared_ptr weight, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> ElementAdd( - std::shared_ptr lhs, std::shared_ptr rhs, - ClampParams params = ClampParams(), - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> ElementAdd( - std::shared_ptr lhs, float rhs, - ClampParams params = ClampParams(), - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> ElementMul( - std::shared_ptr lhs, std::shared_ptr rhs, - ClampParams params = ClampParams(), - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> ElementMul( - std::shared_ptr lhs, float rhs, - ClampParams params = ClampParams(), - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> ElementDiv( - std::shared_ptr lhs, std::shared_ptr rhs, - ClampParams params = ClampParams(), - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> ElementDiv( - std::shared_ptr lhs, float rhs, - ClampParams params = ClampParams(), - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> Rope( - std::shared_ptr input, std::shared_ptr segment_pos, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> PerDimScale( - std::shared_ptr input, std::shared_ptr per_dim_scale, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> DotAttention( - std::shared_ptr query_proj, std::shared_ptr key_proj, - std::shared_ptr value_proj, std::shared_ptr atten_mask, - std::shared_ptr per_dim_scale, - absl::SourceLocation loc = absl::SourceLocation::current()); - - protected: - absl::StatusOr> IntermediateTensor( - Tensor::DimsType dims, - absl::SourceLocation loc = absl::SourceLocation::current()); - - const xnn_datatype data_type_; - - std::vector>> - build_steps_; - - absl::flat_hash_set> input_tensors_; - absl::flat_hash_set> interm_tensors_; - - // TODO: fix this. - // This is sort of bug that the weights used for rope has to be defined with - // EXTERNAL flag, but with id out of the external range. - absl::flat_hash_set> rope_weigths_; - - // Caches - absl::flat_hash_map< - size_t /*dim*/, - absl::flat_hash_map>> - per_dim_scale_cache_; -}; - -class XnnGraph { - public: - XnnGraph(XnnSubgraphPtr subgraph, - std::unique_ptr runtime_configs) - : owned_subgraph_(std::move(subgraph)), - runtime_configs_(std::move(runtime_configs)) { - DCHECK(runtime_configs_); - } - XnnGraph(XnnGraph&& other) = default; - virtual ~XnnGraph() = default; - - // xnn_subgraph should be created with same size. - virtual absl::Status Run(); - - protected: - friend class XnnGraphBuilder; - - absl::Status CreateRuntime(); - absl::Status SetupRuntime(); - - XnnSubgraphPtr owned_subgraph_; - - absl::flat_hash_map avg_cache_; - absl::flat_hash_map cap_tanh_cache_; - - // Runtime - std::unique_ptr runtime_configs_; - XnnRuntimePtr runtime_{nullptr, xnn_delete_runtime}; - std::vector externals_; - - XnnThreadpoolPtr threadpool_{nullptr, pthreadpool_destroy}; - - absl::flat_hash_set> input_tensors_; - absl::flat_hash_set> output_tensors_; - // TODO: see above - absl::flat_hash_set> rope_weigths_; - - absl::flat_hash_set> interm_tensors_; -}; - -} // namespace xnn_utils -} // namespace mediapipe - -#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_ diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.cc b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.cc deleted file mode 100644 index f60e53394..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.cc +++ /dev/null @@ -1,475 +0,0 @@ -#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h" - -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/text/text_generator/calculators/preprocessor_util.h" -#include "mediapipe/tasks/cc/text/text_generator/calculators/sampler_util.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" -#include "util/gtl/stl_logging.h" - -namespace mediapipe { -namespace xnn_utils { -namespace { - -absl::StatusOr> ApplyFinalProj( - std::shared_ptr inter_layer, const UlmWeights& weights, - XnnGraphBuilder& builder) { - return builder.FullConn(inter_layer, weights.softmax_linear, - weights.softmax_bias); -} - -} // namespace - -class OneTokenUlm : public Ulm { - public: - OneTokenUlm(std::unique_ptr full_ulm, XnnGraph&& other) - : Ulm(std::move(other)), full_ulm_(std::move(full_ulm)) {} - ~OneTokenUlm() override = default; - - absl::Status InitInputTokens(const std::vector& input_ids) override { - prev_ids_ = input_ids; - MP_RETURN_IF_ERROR(full_ulm_->InitInputTokens(input_ids)); - // prev_id.size - 1 is the output. - return full_ulm_->Run(); - } - - absl::Status GetNextToken(std::vector* output_ids) override { - size_t decode_step = prev_ids_.size() - 1; - VLOG(2) << "Decode step " << decode_step; - - if (decode_step == ulm_params_.seq_size_T - 1) { - return absl::OutOfRangeError( - absl::StrCat("Hit max sequence length ", ulm_params_.seq_size_T)); - } - - transformer_input_->Borrow( - full_ulm_->transformer_input_->Slice(1, decode_step)); - atten_masks_->Borrow(full_ulm_->atten_masks_->Slice(0, decode_step)); - MP_RETURN_IF_ERROR(segment_pos_->LoadFromBuffer( - full_ulm_->segment_pos_->Slice(0, decode_step)->Data())); - for (auto& kv_cache : kv_cache_) { - DCHECK(kv_cache.k_slice); - DCHECK(kv_cache.v_slice); - kv_cache.k_slice->Borrow(kv_cache.k_cache->Slice(1, decode_step)); - kv_cache.v_slice->Borrow(kv_cache.v_cache->Slice(1, decode_step)); - } - - MP_RETURN_IF_ERROR(SetupRuntime()); - MP_RETURN_IF_ERROR(Run()); - - RET_CHECK(logits_output_); - DCHECK_EQ(logits_output_->num_elements, ulm_params_.voc_size_V); - - ASSIGN_OR_RETURN(*output_ids, - mediapipe::SampleNextToken( - logits_output_->DataAs(), - /*batch_size=*/1, - /*vocab_size=*/ulm_params_.voc_size_V, /*top_k=*/10, - /*top_p=*/1, /*temperature=*/-1)); - RET_CHECK_EQ(output_ids->size(), 1); - prev_ids_.push_back(output_ids->at(0)); - - return GetTokenEmbedding( - *output_ids, - pos_embedding_data_->Slice({decode_step + 1, 0})->DataAs(), - full_ulm_->transformer_input_->Slice({0, decode_step + 1, 0}) - ->DataAs()); - } - - private: - std::unique_ptr full_ulm_; -}; - -absl::StatusOr> UlmBuilder::SelfAttentionExcludeNorm( - std::shared_ptr input, SelfAttentionArgs args, - const SelfAttentionWeights& sa_weights, absl::SourceLocation loc) { - // [B, 1|T, N, H] - ASSIGN_OR_RETURN(auto k_proj, SelfAttentionProj(input, sa_weights.k_weight)); - ASSIGN_OR_RETURN(auto q_proj, SelfAttentionProj(input, sa_weights.q_weight)); - ASSIGN_OR_RETURN(auto v_proj, SelfAttentionProj(input, sa_weights.v_weight)); - - ASSIGN_OR_RETURN(auto query_proj_after_rope, Rope(q_proj, args.segment_pos)); - ASSIGN_OR_RETURN(auto key_proj_after_rope, Rope(k_proj, args.segment_pos)); - - if (args.cache) { - RET_CHECK(args.cache->k_cache); - RET_CHECK(args.cache->v_cache); - // When cache is provided, there are 2 cases: - if (*(input->dims.end() - 2) != 1) { - // Building a normal graph, which is used to initialize cache. - key_proj_after_rope->Borrow(args.cache->k_cache).MarkOutput(); - v_proj->Borrow(args.cache->v_cache).MarkOutput(); - } else { - // Building a one-token graph, which consumes initialized cache. - key_proj_after_rope->MarkOutput(); - args.cache->k_slice = key_proj_after_rope; - v_proj->MarkOutput(); - args.cache->v_slice = v_proj; - - ASSIGN_OR_RETURN(key_proj_after_rope, - NewInput(args.cache->k_cache->dims)); - key_proj_after_rope->Borrow(args.cache->k_cache); - ASSIGN_OR_RETURN(v_proj, NewInput(args.cache->v_cache->dims)); - v_proj->Borrow(args.cache->v_cache); - } - } - - // encoded, [B, 1|T, N, H] - ASSIGN_OR_RETURN( - auto kqv_merged, - DotAttention(query_proj_after_rope, key_proj_after_rope, v_proj, - args.atten_mask, sa_weights.per_dim_scale)); - - const size_t B = kqv_merged->dims[0]; - const size_t T_or_1 = kqv_merged->dims[1]; - const size_t NH = kqv_merged->num_elements / (B * T_or_1); - ASSIGN_OR_RETURN(auto outcome_reshaped, Reshape(kqv_merged, {B, T_or_1, NH})); - - return MatMul(outcome_reshaped, sa_weights.post_proj_weight, - {.transpose = false}); -} - -absl::StatusOr> -UlmBuilder::SelfAttentionIncludeResidual(std::shared_ptr input, - SelfAttentionArgs args, - const SelfAttentionWeights& params, - absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto pre_attention, RmsNorm(input, params.pre_norm)); - - ASSIGN_OR_RETURN( - auto post_attention, - SelfAttentionExcludeNorm(pre_attention, std::move(args), params)); - - ASSIGN_OR_RETURN(auto post_norm, RmsNorm(post_attention, params.post_norm)); - - return ElementAdd(input, post_norm); -} - -absl::StatusOr> UlmBuilder::FeedForwardExcludeResidual( - std::shared_ptr input, const FeedForwardWeights& params, - absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto first_rms_norm, RmsNorm(input, params.pre_norm)); - - ASSIGN_OR_RETURN(auto layer_1, FullConn(first_rms_norm, params.layer_1_weight, - params.layer_1_bias)); - - ASSIGN_OR_RETURN(auto layer_1_gate_before_gelu, - FullConn(first_rms_norm, params.layer_1_gate_weight, - params.layer_1_gate_bias)); - ASSIGN_OR_RETURN(auto layer_1_gate, Gelu(layer_1_gate_before_gelu)); - - ASSIGN_OR_RETURN(auto layer_1_and_gate, ElementMul(layer_1, layer_1_gate)); - if (params.opt_padding) { - // activations *= 1.0 - paddings - ASSIGN_OR_RETURN(auto tmp, ElementMul(params.opt_padding, -1.0f)); - ASSIGN_OR_RETURN(tmp, ElementMul(layer_1_and_gate, tmp)); - ASSIGN_OR_RETURN(layer_1_and_gate, ElementAdd(tmp, layer_1_and_gate)); - } - ASSIGN_OR_RETURN( - auto layer_2, - FullConn(layer_1_and_gate, params.layer_2_weight, params.layer_2_bias)); - if (params.opt_padding) { - // activations *= 1.0 - paddings - ASSIGN_OR_RETURN(auto tmp, ElementMul(params.opt_padding, -1.0f)); - ASSIGN_OR_RETURN(tmp, ElementMul(layer_2, tmp)); - ASSIGN_OR_RETURN(layer_2, ElementAdd(tmp, layer_2)); - } - - return RmsNorm(layer_2, params.post_norm); -} - -absl::StatusOr> UlmBuilder::FeedForwardIncludeResidual( - std::shared_ptr input, const FeedForwardWeights& params, - absl::SourceLocation loc) { - ASSIGN_OR_RETURN(auto before_residual, - FeedForwardExcludeResidual(input, params)); - return ElementAdd(before_residual, input); -} - -absl::StatusOr> Ulm::CreateUlm( - absl::string_view weights_folder, const UlmParams& ulm_params, - std::unique_ptr runtime_configs) { - auto weight_loader = - std::make_unique(weights_folder, ulm_params); - return CreateUlm(std::move(weight_loader), std::move(runtime_configs)); -} - -absl::StatusOr> Ulm::CreateOneTokenUlm( - std::unique_ptr weight_loader, - std::unique_ptr runtime_configs) { - UlmBuilder builder; - // TODO: might be memory waste here, benchmark. - weight_loader->SetBuilder(builder); - ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights()); - - UlmParams ulm_params = weight_loader->ulm_params(); - ulm_params.enable_kv_cache = true; - - weight_loader->ulm_params().enable_kv_cache = true; - weight_loader->ulm_params().final_norm = false; - weight_loader->ulm_params().final_project = false; - ASSIGN_OR_RETURN(auto full_ulm, CreateUlm(std::move(weight_loader))); - - ASSIGN_OR_RETURN(auto input, builder.NewInput({ulm_params.batch_size_B, 1, - ulm_params.model_dim_D})); - ASSIGN_OR_RETURN(auto atten_masks, - builder.NewInput({1, ulm_params.seq_size_T})); - ASSIGN_OR_RETURN(auto segment_pos, - builder.NewWeight({1, ulm_params.head_dim_H})); - // To allocate buffer before creating runtime. - MP_RETURN_IF_ERROR(segment_pos->LoadFromVec({}, /*exact_match=*/false)); - - std::vector& kv_cache = full_ulm->kv_cache_; - RET_CHECK_EQ(kv_cache.size(), ulm_params.num_transformer_M); - - auto inter_layer = input; - for (int i = 0; i < ulm_params.num_transformer_M; ++i) { - const auto& sa = weights.sas[i]; - ASSIGN_OR_RETURN(auto tmp, builder.SelfAttentionIncludeResidual( - inter_layer, - {.atten_mask = atten_masks, - .segment_pos = segment_pos, - .cache = &kv_cache[i]}, - sa)); - - auto& ff = weights.ffs[i]; - // ff.opt_padding = paddings; - ASSIGN_OR_RETURN(inter_layer, builder.FeedForwardIncludeResidual(tmp, ff)); - } - - std::shared_ptr logits_output, transformer_output, normed_output; - - if (ulm_params.final_norm) { - ASSIGN_OR_RETURN(inter_layer, - builder.RmsNorm(inter_layer, weights.final_ln_scale)); - normed_output = inter_layer; - normed_output->MarkOutput(); - } - if (ulm_params.final_project) { - RET_CHECK(weights.softmax_linear); - ASSIGN_OR_RETURN(logits_output, - ApplyFinalProj(inter_layer, weights, builder)); - logits_output->MarkOutput(); - } - - ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs))); - Ulm* full_ulm_p = full_ulm.get(); - auto result = - std::make_unique(std::move(full_ulm), std::move(*graph)); - { - Tensor::DimsType dims{ulm_params.seq_size_T, ulm_params.model_dim_D}; - result->pos_embedding_data_ = - std::make_shared(std::move(dims), xnn_datatype_fp32); - result->pos_embedding_data_->Borrow(full_ulm_p->pos_embedding_data_); - } - result->transformer_input_ = input; - result->transformer_output_ = transformer_output; - result->normed_output_ = normed_output; - result->logits_output_ = logits_output; - result->segment_pos_ = segment_pos; - result->atten_masks_ = atten_masks; - if (ulm_params.use_padding) { - // result->paddings_ = paddings; - } - result->kv_cache_ = std::move(kv_cache); - - result->weights_ = std::move(weights); - result->ulm_params_ = ulm_params; - - return result; -} - -absl::StatusOr> Ulm::CreateUlm( - std::unique_ptr weight_loader, - std::unique_ptr runtime_configs) { - UlmBuilder builder; - weight_loader->SetBuilder(builder); - const auto& ulm_params = weight_loader->ulm_params(); - RET_CHECK_NE(ulm_params.batch_size_B, 0); - - ASSIGN_OR_RETURN(auto input, builder.NewInput({ulm_params.batch_size_B, - ulm_params.seq_size_T, - ulm_params.model_dim_D})); - ASSIGN_OR_RETURN(auto atten_masks, builder.NewInput({ulm_params.seq_size_T, - ulm_params.seq_size_T})); - VLOG(1) << "atten mask id " << atten_masks->tensor_id; - ASSIGN_OR_RETURN( - auto segment_pos, - builder.NewWeight({ulm_params.seq_size_T, ulm_params.head_dim_H})); - MP_RETURN_IF_ERROR(FillXnnRoPEWeights(*segment_pos)); - VLOG(1) << "segment pos id " << segment_pos->tensor_id; - std::shared_ptr paddings; - if (ulm_params.use_padding) { - ASSIGN_OR_RETURN(paddings, builder.NewInput({ulm_params.batch_size_B, - ulm_params.seq_size_T, 1})); - VLOG(1) << "paddings id " << paddings->tensor_id; - } - - ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights()); - std::vector kv_cache; - - auto inter_layer = input; - for (int i = 0; i < ulm_params.num_transformer_M; ++i) { - const auto& sa = weights.sas[i]; - KVCache* cache = nullptr; - if (ulm_params.enable_kv_cache) { - auto k_cache = std::make_shared( - Tensor::DimsType{ulm_params.batch_size_B, ulm_params.seq_size_T, - ulm_params.n_heads_N, ulm_params.head_dim_H}); - MP_RETURN_IF_ERROR(k_cache->LoadFromVec({}, /*exact_match=*/false)); - auto v_cache = std::make_shared( - Tensor::DimsType{ulm_params.batch_size_B, ulm_params.seq_size_T, - ulm_params.n_heads_N, ulm_params.head_dim_H}); - MP_RETURN_IF_ERROR(v_cache->LoadFromVec({}, /*exact_match=*/false)); - kv_cache.push_back(KVCache{.k_cache = k_cache, .v_cache = v_cache}); - cache = &kv_cache.back(); - } - ASSIGN_OR_RETURN(auto tmp, builder.SelfAttentionIncludeResidual( - inter_layer, - {.atten_mask = atten_masks, - .segment_pos = segment_pos, - .cache = cache}, - sa)); - - auto& ff = weights.ffs[i]; - ff.opt_padding = paddings; - ASSIGN_OR_RETURN(inter_layer, builder.FeedForwardIncludeResidual(tmp, ff)); - } - - std::shared_ptr logits_output, transformer_output, normed_output; - - if (!ulm_params.final_norm && !ulm_params.final_project) { - transformer_output = inter_layer; - transformer_output->MarkOutput(); - } - - if (ulm_params.final_norm) { - ASSIGN_OR_RETURN(inter_layer, - builder.RmsNorm(inter_layer, weights.final_ln_scale)); - normed_output = inter_layer; - normed_output->MarkOutput(); - } - - if (ulm_params.final_project) { - RET_CHECK(weights.softmax_linear); - ASSIGN_OR_RETURN(logits_output, - ApplyFinalProj(inter_layer, weights, builder)); - logits_output->MarkOutput(); - } - - ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs))); - auto ulm = std::make_unique(std::move(*graph)); - { - ASSIGN_OR_RETURN(auto pos_embedding_data, - mediapipe::PositionEmbedding(ulm_params.seq_size_T, - ulm_params.model_dim_D)); - Tensor::DimsType dims{ulm_params.seq_size_T, ulm_params.model_dim_D}; - ulm->pos_embedding_data_ = - std::make_shared(std::move(dims), xnn_datatype_fp32); - MP_RETURN_IF_ERROR( - ulm->pos_embedding_data_->LoadFromVec(pos_embedding_data)); - } - ulm->transformer_input_ = input; - ulm->transformer_output_ = transformer_output; - ulm->normed_output_ = normed_output; - ulm->logits_output_ = logits_output; - ulm->segment_pos_ = segment_pos; - ulm->atten_masks_ = atten_masks; - if (ulm_params.use_padding) { - ulm->paddings_ = paddings; - } - ulm->kv_cache_ = std::move(kv_cache); - - ulm->weights_ = std::move(weights); - ulm->ulm_params_ = ulm_params; - - return ulm; -} - -absl::Status Ulm::InitInputTokens(const std::vector& input_ids) { - prev_ids_ = input_ids; - - constexpr float neg_value = 0.7 * std::numeric_limits::lowest(); - const auto& seq_size = ulm_params_.seq_size_T; - std::vector attention_array(seq_size * seq_size, neg_value); - for (int i = 0; i < seq_size; ++i) { - for (int j = 0; j < seq_size; ++j) { - if (i < input_ids.size() && j < input_ids.size()) { - attention_array[seq_size * i + j] = 0; - } else if (i >= seq_size && j <= i) { - attention_array[seq_size * i + j] = 0; - } else { - break; - } - } - } - - MP_RETURN_IF_ERROR(atten_masks_->LoadFromVec(attention_array)); - - MP_RETURN_IF_ERROR(GetTokenEmbedding(input_ids, - pos_embedding_data_->DataAs(), - transformer_input_->DataAs())); - return SetupRuntime(); -} - -absl::Status Ulm::GetNextToken(std::vector* output_ids) { - VLOG(2) << "Decode step " << prev_ids_.size() - 1; - - MP_RETURN_IF_ERROR(Run()); - - RET_CHECK(logits_output_); - std::shared_ptr logits = - logits_output_->Slice({0, prev_ids_.size() - 1, 0}); - DCHECK_EQ(logits->num_elements, ulm_params_.voc_size_V); - - ASSIGN_OR_RETURN(*output_ids, - mediapipe::SampleNextToken( - logits->DataAs(), - /*batch_size=*/1, - /*vocab_size=*/ulm_params_.voc_size_V, /*top_k=*/10, - /*top_p=*/1, /*temperature=*/-1)); - RET_CHECK_EQ(output_ids->size(), 1); - prev_ids_.push_back(output_ids->at(0)); - - return GetTokenEmbedding( - *output_ids, - pos_embedding_data_->Slice({prev_ids_.size() - 1, 0})->DataAs(), - transformer_input_->Slice({0, prev_ids_.size() - 1, 0})->DataAs()); -} - -absl::Status Ulm::GetTokenEmbedding(const std::vector& ids, - const float* pos_embedding_data, - float* embedding) { - auto token_embedding = weights_.token_embedding ? weights_.token_embedding - : weights_.softmax_linear; - RET_CHECK(token_embedding->dims[0] == ulm_params_.voc_size_V) - << "shape must be [vocab_size, _], such that following Slice() makes " - "sense."; - for (size_t id : ids) { - memcpy(embedding, token_embedding->Slice(0, id)->Data(), - ulm_params_.model_dim_D * sizeof(float)); - for (size_t i = 0; i < ulm_params_.model_dim_D; ++i) { - embedding[i] += pos_embedding_data[i]; - } - pos_embedding_data += ulm_params_.model_dim_D; - embedding += ulm_params_.model_dim_D; - } - return absl::OkStatus(); -} - -} // namespace xnn_utils -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h deleted file mode 100644 index 7bf7de5a9..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h +++ /dev/null @@ -1,127 +0,0 @@ -#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_ -#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" - -namespace mediapipe { -namespace xnn_utils { - -class Ulm : public XnnGraph { - public: - using UlmParams = UlmParams; - - explicit Ulm(XnnGraph&& other) : XnnGraph(std::move(other)) {} - ~Ulm() override = default; - - // Creating ULM graph with default params. The default param corresponds to - // ULM1B 256k model. - static absl::StatusOr> CreateUlm( - absl::string_view weights_folder, - const UlmParams& ulm_params = - UlmParams{ - .num_transformer_M = 18, - .batch_size_B = 1, - .seq_size_T = 16, - .model_dim_D = 1536, - .hidden_dim_HD = 8 * 1536, - .head_dim_H = 128, - .n_heads_N = 12, - .voc_size_V = 256128, - }, - std::unique_ptr runtime_configs = nullptr); - static absl::StatusOr> CreateUlm( - std::unique_ptr weight_loader, - std::unique_ptr runtime_configs = nullptr); - // Build the graph for one-token inference. - static absl::StatusOr> CreateOneTokenUlm( - std::unique_ptr weight_loader, - std::unique_ptr runtime_configs = nullptr); - - // (Re)Initialize with input token ids. This will reset the cache, mask etc. - virtual absl::Status InitInputTokens(const std::vector& input_ids); - - // Get the next token id. - virtual absl::Status GetNextToken(std::vector* output_ids); - - protected: - friend class OneTokenUlm; - friend class UlmTest; - friend class UlmBuilder; - - // Enable if enable_kv_cache - struct KVCache { - std::shared_ptr k_cache; - std::shared_ptr v_cache; - std::shared_ptr k_slice; - std::shared_ptr v_slice; - }; - - absl::Status GetTokenEmbedding(const std::vector& ids, - const float* pos_embedding_data, - float* embedding); - - UlmWeights weights_; - UlmParams ulm_params_; - - std::shared_ptr pos_embedding_data_; - std::shared_ptr atten_masks_; - std::shared_ptr segment_pos_; - std::shared_ptr paddings_; - - std::shared_ptr transformer_input_; - std::shared_ptr transformer_output_; - std::shared_ptr normed_output_; - std::shared_ptr logits_output_; - - // Previous ids, including prompt. - std::vector prev_ids_; - // If enable_kv_cache, expect a mask of [0, ... 0, 1, 0, 0...], size 1 x T. - std::shared_ptr decode_step_mask_; - // [1, 1, ..., 1, 0, 0...], applied on cache - std::shared_ptr decode_step_mask_for_cache_; - std::vector kv_cache_; -}; - -class UlmBuilder : public XnnGraphBuilder { - public: - struct SelfAttentionArgs { - std::shared_ptr atten_mask; - std::shared_ptr segment_pos; - - Ulm::KVCache* cache = nullptr; - }; - - absl::StatusOr> SelfAttentionExcludeNorm( - std::shared_ptr input, SelfAttentionArgs args, - const SelfAttentionWeights& sa_weights, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> SelfAttentionIncludeResidual( - std::shared_ptr input, SelfAttentionArgs args, - const SelfAttentionWeights& params, - absl::SourceLocation loc = absl::SourceLocation::current()); - - absl::StatusOr> FeedForwardExcludeResidual( - std::shared_ptr input, const FeedForwardWeights& params, - absl::SourceLocation loc = absl::SourceLocation::current()); - absl::StatusOr> FeedForwardIncludeResidual( - std::shared_ptr input, const FeedForwardWeights& params, - absl::SourceLocation loc = absl::SourceLocation::current()); -}; - -} // namespace xnn_utils -} // namespace mediapipe - -#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_ diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.cc b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.cc deleted file mode 100644 index a33589a60..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.cc +++ /dev/null @@ -1,366 +0,0 @@ -#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "file/base/filesystem.h" -#include "file/base/options.h" -#include "file/base/path.h" -#include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" -#include "third_party/XNNPACK/include/xnnpack.h" - -namespace mediapipe { -namespace xnn_utils { - -namespace { - -absl::StatusOr> LoadFromAbsPathPrefixHelper( - XnnGraphBuilder& builder, absl::string_view prefix, - const Tensor::DimsType& dims, size_t dim_scale_if_any) { - RET_CHECK(!prefix.empty() && prefix.back() != '.'); - std::vector filenames; - auto s = file::Match(absl::StrCat(prefix, "*"), &filenames, file::Defaults()); - if (!s.ok()) { - LOG(WARNING) << s; - return nullptr; - } else if (filenames.empty()) { - return nullptr; - } - - if (filenames.size() == 1) { - RET_CHECK_EQ(filenames[0], prefix); - return builder.NewWeight(filenames[0], dims); - } - - bool is_quantized_tensor = false; - for (const auto& filename : filenames) { - if (absl::StrContains(filename, kQuantizedScaleSuffix)) { - is_quantized_tensor = true; - continue; - } - } - - RET_CHECK(is_quantized_tensor) - << "At least one of {" << filenames << "} must be quantize scale file."; - - std::shared_ptr result; - result = std::make_shared(dims, dim_scale_if_any); - - MP_RETURN_IF_ERROR(result->LoadFromFile(prefix)); - builder.NewWeight(result); - - return result; -} - -absl::Status TransposeSelfAttentionWeight( - const UlmWeightsLoader& loader, std::shared_ptr& original_weight, - absl::string_view cache_file_prefix) { - const auto& ulm_param = loader.ulm_params(); - RET_CHECK(original_weight); - - std::optional from_cache = - original_weight->GetMetadata(UlmWeights::kKeyLoadedFromCache); - if (from_cache && *from_cache) { - return absl::OkStatus(); - } - - if (auto s = original_weight->DumpToFile(cache_file_prefix); !s.ok()) { - LOG(WARNING) << s; - } else { - MP_RETURN_IF_ERROR(original_weight->LoadFromFile(cache_file_prefix)); - } - loader.builder().NewWeight(original_weight); - original_weight->SetMetadata(XnnGraphBuilder::kKeySelfAttentionReshapedWeight, - ulm_param.n_heads_N); - return absl::OkStatus(); -} - -} // namespace - -absl::Status PrepareTokenEmbeddingDecorator::Decorate( - const UlmWeightsLoader& loader, UlmWeights& weight) { - if (weight.token_embedding) { - return absl::OkStatus(); - } - - const auto& ulm_params = loader.ulm_params(); - absl::string_view cache_path = loader.ulm_params().weight_cache_path; - std::string token_embedding_cache_path = - cache_path.empty() ? "" : file::JoinPath(cache_path, "token_embedding.w"); - // 1. try cache - if (!token_embedding_cache_path.empty()) { - auto token_embedding = - Tensor::FromFile(token_embedding_cache_path, - {ulm_params.voc_size_V, ulm_params.model_dim_D}); - if (token_embedding.ok()) { - weight.token_embedding = *token_embedding; - return absl::OkStatus(); - } - } - - // 2. fill embedding from softmax_linear - auto& softmax_linear = *weight.softmax_linear; - RET_CHECK(softmax_linear.dims[0] == ulm_params.voc_size_V) << softmax_linear; - if (softmax_linear.datatype == xnn_datatype_fp32) { - weight.token_embedding = softmax_linear.View(); - } else if (softmax_linear.datatype == xnn_datatype_qcint8) { - ASSIGN_OR_RETURN(weight.token_embedding, softmax_linear.ConvertToF32()); - } - - float* embedding_data = weight.token_embedding->DataAs(); - for (size_t i = 0; i < softmax_linear.num_elements; ++i) { - embedding_data[i] *= std::sqrt(loader.ulm_params().model_dim_D); - } - - // 3. save cache - if (!token_embedding_cache_path.empty()) { - MP_RETURN_IF_ERROR( - weight.token_embedding->DumpToFile(token_embedding_cache_path)); - return weight.token_embedding->LoadFromFile(token_embedding_cache_path); - } - - return absl::OkStatus(); -} - -absl::Status TransposeSelfAttentionWeightDecorator::Decorate( - const UlmWeightsLoader& loader, UlmWeights& weight) { - absl::string_view cache_path = loader.ulm_params().weight_cache_path; - if (cache_path.empty()) { - return absl::OkStatus(); - } - - for (size_t i = 0; i < weight.sas.size(); ++i) { - auto& sa = weight.sas[i]; - auto prefix = absl::StrCat(UlmWeightsLoader::kTransformerWeightPrefix, i, - ".self_attention."); - MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight( - loader, sa.k_weight, - file::JoinPath(cache_path, absl::StrCat(prefix, "k.w")))); - MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight( - loader, sa.q_weight, - file::JoinPath(cache_path, absl::StrCat(prefix, "q.w")))); - MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight( - loader, sa.v_weight, - file::JoinPath(cache_path, absl::StrCat(prefix, "v.w")))); - } - - return absl::OkStatus(); -} - -absl::StatusOr> UlmWeightsLoader::LoadFromAbsPathPrefix( - absl::string_view prefix, const Tensor::DimsType& dims, - size_t dim_scale_if_any) const { - return LoadFromAbsPathPrefixHelper(*builder_, prefix, dims, dim_scale_if_any); -} - -absl::StatusOr> -UlmWeightsLoader::TryCacheThenLoadSelfAttention( - absl::string_view filename_prefix) const { - ASSIGN_OR_RETURN( - auto r, - TryCacheThenLoadWeightTranspose( - filename_prefix, - {params_.model_dim_D, params_.n_heads_N * params_.head_dim_H}, 1)); - r->SetMetadata(XnnGraphBuilder::kKeySelfAttentionReshapedWeight, - params_.n_heads_N); - return r; -} - -absl::StatusOr> -UlmWeightsLoader::TryCacheThenLoadFeedForward( - absl::string_view filename_prefix, - std::optional dims) const { - if (!dims) { - dims = {params_.model_dim_D, params_.hidden_dim_HD}; - } - return TryCacheThenLoadWeightTranspose(filename_prefix, *dims, 1); -} - -absl::StatusOr> -UlmWeightsLoader::TryCacheThenLoadWeightTranspose( - absl::string_view filename_prefix, Tensor::DimsType original_dims, - size_t original_dim_cale) const { - if (!params_.weight_cache_path.empty()) { - auto cache_full_prefix = - file::JoinPath(params_.weight_cache_path, filename_prefix); - Tensor::DimsType cache_dim{original_dims.rbegin(), original_dims.rend()}; - ASSIGN_OR_RETURN(auto r, LoadFromAbsPathPrefix( - cache_full_prefix, std::move(cache_dim), - /*dim_scale_if_any=*/1 - original_dim_cale)); - if (r) { - r->SetMetadata(UlmWeights::kKeyLoadedFromCache, 1); - return r; - } - } - - ASSIGN_OR_RETURN(auto r, LoadFromAbsPathPrefix( - file::JoinPath(weight_path_, filename_prefix), - std::move(original_dims), - /*dim_scale_if_any=*/original_dim_cale)); - RET_CHECK(r) << file::JoinPath(weight_path_, filename_prefix); - r = r->Transpose(); - builder_->NewWeight(r); - return r; -} - -absl::StatusOr UlmWeightsLoader::LoadFeedForward( - int layer_id) { - absl::string_view weights_folder = weight_path_; - const auto& params = params_; - auto ff_file_prefix = - absl::StrCat(kTransformerWeightPrefix, layer_id, ".ff_layer."); - auto ff_prefix = file::JoinPath(weights_folder, ff_file_prefix); - FeedForwardWeights feed_forward; - - ASSIGN_OR_RETURN( - feed_forward.pre_norm, - LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "pre_layer_norm.scale"), - {params.model_dim_D})); - ASSIGN_OR_RETURN( - feed_forward.post_norm, - LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "post_layer_norm.scale"), - {params.model_dim_D})); - ASSIGN_OR_RETURN( - feed_forward.layer_1_bias, - LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer1.bias.b"), - {params.hidden_dim_HD})); - ASSIGN_OR_RETURN(feed_forward.layer_1_weight, - TryCacheThenLoadFeedForward( - absl::StrCat(ff_file_prefix, "ffn_layer1.linear.w"))); - ASSIGN_OR_RETURN( - feed_forward.layer_1_gate_bias, - LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer1_gate.bias.b"), - {params.hidden_dim_HD})); - ASSIGN_OR_RETURN(feed_forward.layer_1_gate_weight, - TryCacheThenLoadFeedForward(absl::StrCat( - ff_file_prefix, "ffn_layer1_gate.linear.w"))); - ASSIGN_OR_RETURN( - feed_forward.layer_2_bias, - LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer2.bias.b"), - {params.model_dim_D}, /*dim_scale_if_any=*/0)); - ASSIGN_OR_RETURN( - feed_forward.layer_2_weight, - TryCacheThenLoadFeedForward( - absl::StrCat(ff_file_prefix, "ffn_layer2.linear.w"), - Tensor::DimsType{params.hidden_dim_HD, params.model_dim_D})); - - return feed_forward; -} - -absl::StatusOr UlmWeightsLoader::LoadSelfAttention( - int layer_id) { - absl::string_view weights_folder = weight_path_; - const auto& params = params_; - SelfAttentionWeights self_attention; - - auto sa_file_prefix = absl::StrCat(kTransformerWeightPrefix, layer_id); - auto sa_prefix = file::JoinPath(weights_folder, sa_file_prefix); - ASSIGN_OR_RETURN( - self_attention.pre_norm, - LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, ".pre_layer_norm.scale"), - {params.model_dim_D})); - ASSIGN_OR_RETURN( - self_attention.post_norm, - LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, ".post_layer_norm.scale"), - {params.model_dim_D})); - - absl::StrAppend(&sa_file_prefix, ".self_attention."); - - ASSIGN_OR_RETURN( - self_attention.k_weight, - TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "k.w"))); - ASSIGN_OR_RETURN( - self_attention.q_weight, - TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "q.w"))); - ASSIGN_OR_RETURN( - self_attention.v_weight, - TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "v.w"))); - - sa_prefix = file::JoinPath(weights_folder, sa_file_prefix); - ASSIGN_OR_RETURN(self_attention.per_dim_scale, - LoadFromAbsPathPrefix( - absl::StrCat(sa_prefix, "per_dim_scale.per_dim_scale"), - {params.head_dim_H})); - ASSIGN_OR_RETURN(self_attention.post_proj_weight, - LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, "post.w"), - {params.model_dim_D, - params.n_heads_N * params.head_dim_H}, - /*dim_scale_if_any=*/0)); - - return self_attention; -} - -absl::StatusOr UlmWeightsLoader::LoadWeights() { - absl::string_view weights_folder = weight_path_; - const auto& params = params_; - UlmWeights result; - - for (int layer_id = 0; layer_id < params.num_transformer_M; ++layer_id) { - ASSIGN_OR_RETURN(auto ff, LoadFeedForward(layer_id)); - result.ffs.push_back(std::move(ff)); - ASSIGN_OR_RETURN(auto sa, LoadSelfAttention(layer_id)); - result.sas.push_back(std::move(sa)); - } - if (params.final_norm) { - ASSIGN_OR_RETURN(result.final_ln_scale, - LoadFromAbsPathPrefix( - file::JoinPath(weights_folder, kFinalScaleFilename), - {params.model_dim_D})); - } - ASSIGN_OR_RETURN(result.softmax_bias, - LoadFromAbsPathPrefix( - file::JoinPath(weights_folder, kLogitsFfnBiasFilename), - {params.voc_size_V})); - ASSIGN_OR_RETURN(result.softmax_linear, - TryCacheThenLoadWeightTranspose( - kLogitsFfnWeightFilename, - {params.model_dim_D, params.voc_size_V}, 1)); - - return result; -} - -BenchmarkUlmWeightsLoader::BenchmarkUlmWeightsLoader(const UlmParams& params, - xnn_datatype data_type) - : DefaultUlmWeightsLoader("", params), data_type_(data_type) { - params_.weight_cache_path.clear(); -} - -absl::StatusOr> -BenchmarkUlmWeightsLoader::TryCacheThenLoadWeightTranspose( - absl::string_view filename_prefix, Tensor::DimsType original_dims, - size_t original_dim_cale) const { - auto result = std::make_shared( - Tensor::DimsType{original_dims.rbegin(), original_dims.rend()}, - 1 - original_dim_cale); - auto real_data = std::make_shared(result->num_elements, 0xA5); - result->flat_data = std::shared_ptr(real_data, real_data->data()); - auto real_scale = std::make_shared>( - original_dims[original_dim_cale], 1.0f); - result->scale_data = std::shared_ptr(real_scale, real_scale->data()); - builder_->NewWeight(result); - return result; -} - -absl::StatusOr> -BenchmarkUlmWeightsLoader::LoadFromAbsPathPrefix( - absl::string_view prefix, const Tensor::DimsType& dims, - size_t dim_scale_if_any) const { - // If loader calls this function directly, it's always non-quantized weights. - auto result = std::make_shared(dims); - MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false)); - builder_->NewWeight(result); - return result; -} - -} // namespace xnn_utils -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h b/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h deleted file mode 100644 index f10d8706a..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h +++ /dev/null @@ -1,192 +0,0 @@ -#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_ -#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_ - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" -#include "third_party/XNNPACK/include/xnnpack.h" - -namespace mediapipe { -namespace xnn_utils { - -struct UlmParams { - size_t num_transformer_M = 18; - size_t batch_size_B = 1; - size_t seq_size_T = 16; - size_t model_dim_D = 1536; - size_t hidden_dim_HD = 8 * 1536; - size_t head_dim_H = 128; - size_t n_heads_N = 12; - size_t voc_size_V = 32000; - - bool use_padding = true; - bool final_norm = true; - bool final_project = true; - - bool enable_kv_cache = false; - // Path to store reshaped weights as cache. Set empty to disable caching. - std::string weight_cache_path; -}; - -struct SelfAttentionWeights { - std::shared_ptr pre_norm; - - std::shared_ptr k_weight; - std::shared_ptr q_weight; - std::shared_ptr v_weight; - std::shared_ptr per_dim_scale; - std::shared_ptr post_proj_weight; - - std::shared_ptr post_norm; -}; - -struct FeedForwardWeights { - std::shared_ptr pre_norm; - std::shared_ptr layer_1_weight; - std::shared_ptr layer_1_bias; - std::shared_ptr layer_1_gate_weight; - std::shared_ptr layer_1_gate_bias; - std::shared_ptr layer_2_weight; - std::shared_ptr layer_2_bias; - std::shared_ptr post_norm; - - std::shared_ptr opt_padding; -}; - -struct UlmWeights { - std::vector ffs; - std::vector sas; - std::shared_ptr final_ln_scale; - std::shared_ptr softmax_linear; - std::shared_ptr softmax_bias; - - // Optional. Usually softmax_linear can be used as embedding, but sometimes we - // need to scale/transpose it. - std::shared_ptr token_embedding; - - static constexpr absl::string_view kKeyLoadedFromCache{"loaded_from_cache"}; -}; - -class UlmWeightsLoader { - public: - constexpr static absl::string_view kTransformerWeightPrefix{ - "params.lm.transformer.x_layers_"}; - constexpr static absl::string_view kFinalScaleFilename{ - "params.lm.final_ln.scale"}; - constexpr static absl::string_view kLogitsFfnBiasFilename{ - "params.lm.softmax.logits_ffn.bias.b"}; - constexpr static absl::string_view kLogitsFfnWeightFilename{ - "params.lm.softmax.logits_ffn.linear.w"}; - - UlmWeightsLoader(absl::string_view weight_path, const UlmParams& params) - : weight_path_(weight_path), params_(params) {} - virtual ~UlmWeightsLoader() = default; - - void SetBuilder(XnnGraphBuilder& builder) { builder_ = &builder; } - - virtual absl::StatusOr LoadWeights(); - - virtual absl::StatusOr LoadSelfAttention(int layer_id); - virtual absl::StatusOr LoadFeedForward(int layer_id); - - UlmParams& ulm_params() { return params_; } - const UlmParams& ulm_params() const { return params_; } - XnnGraphBuilder& builder() const { return *builder_; } - - protected: - // Find the files that matches prefix, then read from file. - virtual absl::StatusOr> LoadFromAbsPathPrefix( - absl::string_view prefix, const Tensor::DimsType& dims, - size_t dim_scale_if_any) const; - absl::StatusOr> LoadFromAbsPathPrefix( - absl::string_view prefix, const Tensor::DimsType& dims) const { - return LoadFromAbsPathPrefix(prefix, dims, 0); - } - - absl::StatusOr> TryCacheThenLoadSelfAttention( - absl::string_view filename_prefix) const; - absl::StatusOr> TryCacheThenLoadFeedForward( - absl::string_view filename_prefix, - std::optional dims = std::nullopt) const; - virtual absl::StatusOr> - TryCacheThenLoadWeightTranspose(absl::string_view filename_prefix, - Tensor::DimsType original_dims, - size_t original_dim_cale) const; - - std::string weight_path_; - UlmParams params_; - XnnGraphBuilder* builder_ = nullptr; -}; - -// Try: 1. load token embedding from cache; 2. fill token embedding by transpose -// softmax linear then scale; 3. dump token embedding to cache. -struct PrepareTokenEmbeddingDecorator { - static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&); -}; -struct TransposeSoftmaxWeightDecorator { - static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&); -}; -struct TransposeSelfAttentionWeightDecorator { - // If KQV weight are reshaped, ignore. - // If KQV weight are not properly shaped, load from cache if any, or build. - // If KQV weight are missing, try loading from cache path, or fail if missing. - static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&); -}; - -// Apply some decoration (in order) to the weights loaded by base class. -template -class UlmWeightsLoaderWith : public UlmWeightsLoader { - public: - UlmWeightsLoaderWith(absl::string_view weight_path, const UlmParams& params) - : UlmWeightsLoader(weight_path, params), - decorators_{Decorators::Decorate...} {} - - absl::StatusOr LoadWeights() override { - ASSIGN_OR_RETURN(auto result, UlmWeightsLoader::LoadWeights()); - for (const auto& decorator : decorators_) { - MP_RETURN_IF_ERROR(decorator(*this, result)); - } - return result; - } - - protected: - std::vector> - decorators_; -}; - -using DefaultUlmWeightsLoader = - UlmWeightsLoaderWith; - -// Generate weights with some random value. -class BenchmarkUlmWeightsLoader : public DefaultUlmWeightsLoader { - public: - explicit BenchmarkUlmWeightsLoader( - const UlmParams& params, xnn_datatype data_type = xnn_datatype_fp32); - - absl::StatusOr> TryCacheThenLoadWeightTranspose( - absl::string_view filename_prefix, Tensor::DimsType original_dims, - size_t original_dim_cale) const override; - - absl::StatusOr> LoadFromAbsPathPrefix( - absl::string_view prefix, const Tensor::DimsType& dims, - size_t dim_scale_if_any) const override; - - private: - xnn_datatype data_type_; - std::shared_ptr random_value_buffer_; -}; - -} // namespace xnn_utils -} // namespace mediapipe - -#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_ diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/utils.cc b/mediapipe/tasks/cc/text/utils/xnn_utils/utils.cc deleted file mode 100644 index 8407892af..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/utils.cc +++ /dev/null @@ -1,21 +0,0 @@ -#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h" - -namespace mediapipe { -namespace xnn_utils { - -std::vector FillXnnRoPEWeights(size_t max_seq_len, size_t num_channels) { - std::vector out_array(max_seq_len * num_channels); - for (size_t ch_id = 0; ch_id < num_channels / 2; ++ch_id) { - auto timescale = std::pow(1e-4, 2.0 * ch_id / num_channels); - for (size_t seq_id = 0; seq_id < max_seq_len; ++seq_id) { - auto sinusoid_inp = seq_id * timescale; - out_array[seq_id * num_channels + ch_id] = cos(sinusoid_inp); - out_array[seq_id * num_channels + ch_id + num_channels / 2] = - sin(sinusoid_inp); - } - } - return out_array; -} - -} // namespace xnn_utils -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/utils.h b/mediapipe/tasks/cc/text/utils/xnn_utils/utils.h deleted file mode 100644 index 7aea30521..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/utils.h +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_ -#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_ - -#include -#include - -#include "absl/cleanup/cleanup.h" -#include "absl/status/statusor.h" -#include "file/base/helpers.h" -#include "file/base/options.h" -#include "mediapipe/framework/port/ret_check.h" - -namespace mediapipe { -namespace xnn_utils { - -std::vector FillXnnRoPEWeights(size_t max_seq_len, size_t num_channels); - -// expect_size_bytes == 0 means don't check size. -template -static absl::StatusOr> LoadBufferFromFile( - absl::string_view file_path, bool use_mmap = true, - size_t expect_size_bytes = 0) { - if (use_mmap) { - int fd = open(file_path.data(), O_RDONLY); - RET_CHECK_GE(fd, 0) << "open " << file_path << " failed"; - auto cleanup = absl::MakeCleanup([fd] { close(fd); }); - - const size_t size = lseek(fd, 0, SEEK_END); - if (expect_size_bytes) { - RET_CHECK_EQ(expect_size_bytes, size) - << "File size " << size << ", expected " << expect_size_bytes - << ", file path " << file_path; - } - - void* data = mmap(/*addr=*/nullptr, size, /*prot=*/PROT_READ, - /*flags=*/MAP_SHARED, fd, /*offset=*/0); - RET_CHECK_NE(data, MAP_FAILED); - RET_CHECK_NE(data, nullptr); - - return std::shared_ptr(static_cast(data), - [](auto* p) {}); - } else { - auto read_buffer = std::make_shared(); - MP_RETURN_IF_ERROR( - file::GetContents(file_path, read_buffer.get(), file::Defaults())); - - if (expect_size_bytes) { - RET_CHECK_EQ(expect_size_bytes, read_buffer->size()) - << "File size " << read_buffer->size() << ", expected " - << expect_size_bytes << ", file path " << file_path; - } - - return std::shared_ptr( - read_buffer, reinterpret_cast(read_buffer->data())); - } -} - -} // namespace xnn_utils -} // namespace mediapipe - -#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.cc b/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.cc deleted file mode 100644 index 8d185ebd9..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.cc +++ /dev/null @@ -1,358 +0,0 @@ -#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h" - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "file/base/helpers.h" -#include "file/base/options.h" -#include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h" -#include "third_party/XNNPACK/include/xnnpack.h" - -namespace mediapipe { -namespace xnn_utils { - -absl::Status FillXnnRoPEWeights(Tensor& out_seg_pos) { - RET_CHECK_EQ(out_seg_pos.dims.size(), 2); - const size_t max_seq_len = out_seg_pos.dims[0]; - const size_t num_channels = out_seg_pos.dims[1]; - return out_seg_pos.LoadFromVec(FillXnnRoPEWeights(max_seq_len, num_channels)); -} - -std::ostream& operator<<(std::ostream& os, const Tensor& tensor) { - os << "Tensor{dims=[" << tensor.dims << "], datatype=" << tensor.datatype - << ", num_elements=" << tensor.num_elements << "}"; - return os; -} - -std::ostream& operator<<(std::ostream& os, const QCTensor& tensor) { - os << "QCTensor{dims=[" << tensor.dims << "], dim_scale=" << tensor.dim_scale - << " datatype=" << tensor.datatype - << ", num_elements=" << tensor.num_elements << "}"; - return os; -} - -bool Tensor::operator==(const Tensor& other) const { - if (dims.size() != other.dims.size()) { - return false; - } else if (datatype != other.datatype) { - return false; - } else { - for (size_t i = 0; i < dims.size(); ++i) { - if (dims[i] != other.dims[i]) { - return false; - } - } - } - return 0 == memcmp(Data(), other.Data(), num_elements * ElementSize()); -} - -void Tensor::AllocateBufferIfNeeded() { - if (!flat_data) { - auto real_buffer = std::make_shared(); - real_buffer->reserve(num_elements * ElementSize() + XNN_EXTRA_BYTES); - flat_data = std::shared_ptr(real_buffer, real_buffer->data()); - } -} - -void* Tensor::Data() { - DCHECK(flat_data) - << "If this is weight, you may need to call one of the LoadFrom*()"; - return flat_data.get(); -} - -std::shared_ptr Tensor::Slice(DimsType offset) { - DCHECK(flat_data); - CHECK_EQ(offset.size(), dims.size()) << offset << " vs. " << dims; - // offset: [0, k, 0, 0], dims: [1, K, _, _]. dims before k must be 1. - bool found_non_zero_offset = false; - int index_k = -1; - for (int i = 0; i < dims.size(); ++i) { - if (found_non_zero_offset) { - DCHECK_EQ(offset[i], 0); - } else if (offset[i] != 0) { - found_non_zero_offset = true; - index_k = i; - } - } - DCHECK(found_non_zero_offset) << offset; - - return Slice(index_k, offset[index_k]); -} - -std::shared_ptr Tensor::Slice(size_t index, size_t offset) { - size_t num_elements_offset = 1; - DimsType new_dim = dims; - for (int i = 0; i < dims.size(); ++i) { - if (i < index) { - DCHECK_EQ(dims[i], 1); - } else if (i == index) { - num_elements_offset *= offset; - new_dim[i] = 1; - } else { - num_elements_offset *= dims[i]; - } - } - - auto result = std::make_shared(std::move(new_dim), datatype); - result->flat_data = std::shared_ptr( - flat_data, flat_data.get() + num_elements_offset * ElementSize()); - return result; -} - -Tensor& Tensor::Borrow(std::shared_ptr other, size_t element_offset) { - DCHECK_EQ(datatype, other->datatype); - DCHECK_EQ(dims.size(), other->dims.size()); - flat_data = std::shared_ptr( - other->flat_data, - other->flat_data.get() + element_offset * ElementSize()); - return *this; -} - -std::shared_ptr Tensor::View() { return View(dims); } - -std::shared_ptr Tensor::View(DimsType as_dims, size_t) { - auto result = std::make_shared(as_dims, datatype); - DCHECK_LE(result->num_elements, num_elements); - result->flat_data = flat_data; - return result; -} - -const void* Tensor::Data() const { return const_cast(this)->Data(); } - -absl::Status Tensor::DefineAsExternal(xnn_subgraph& subgraph, uint32_t flags) { - uint32_t id; - RET_CHECK_EQ(xnn_status_success, - xnn_define_tensor_value(&subgraph, datatype, dims.size(), - dims.data(), /*data=*/nullptr, - /*external_id=*/tensor_id, flags, &id)); - if (tensor_id == XNN_INVALID_VALUE_ID) { - RET_CHECK_NE(id, XNN_INVALID_VALUE_ID); - tensor_id = id; - } else { - RET_CHECK_EQ(id, tensor_id); - } - return absl::OkStatus(); -} - -absl::Status Tensor::DefineAsInput(xnn_subgraph& subgraph) { - return DefineAsExternal(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT); -} - -absl::Status Tensor::DefineAsOutput(xnn_subgraph& subgraph) { - return DefineAsExternal(subgraph, XNN_VALUE_FLAG_EXTERNAL_OUTPUT); -} - -absl::Status Tensor::DefineAsIntermediateTensor(xnn_subgraph& subgraph) { - RET_CHECK_EQ(tensor_id, XNN_INVALID_VALUE_ID); - return DefineAsExternal(subgraph, 0); -} - -absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) { - RET_CHECK_EQ( - xnn_status_success, - xnn_define_tensor_value(&subgraph, datatype, dims.size(), dims.data(), - Data(), tensor_id, flags, &tensor_id)); - RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID); - return absl::OkStatus(); -} - -absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph) { - RET_CHECK_EQ(tensor_id, XNN_INVALID_VALUE_ID); - return DefineWeight(subgraph, 0); -} - -absl::Status Tensor::DefineRope(xnn_subgraph& subgraph) { - RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID); - return DefineWeight(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT); -} - -absl::Status Tensor::LoadFromBuffer(const void* buffer) { - AllocateBufferIfNeeded(); - memcpy(Data(), buffer, num_elements * ElementSize()); - return absl::OkStatus(); -} - -absl::Status Tensor::LoadFromVec(const std::vector& data, - bool exact_match) { - AllocateBufferIfNeeded(); - if (exact_match) { - RET_CHECK_EQ(num_elements * ElementSize(), data.size() * sizeof(float)); - } - - memcpy(Data(), data.data(), data.size() * sizeof(float)); - - return absl::OkStatus(); -} - -absl::Status Tensor::LoadFromVec(std::vector&& data, bool exact_match) { - if (exact_match) { - RET_CHECK_EQ(num_elements * ElementSize(), data.size() * sizeof(float)); - } - - auto real_buffer = std::make_shared>(std::move(data)); - if (real_buffer->size() < num_elements) { - real_buffer->resize(num_elements); - } - flat_data = std::shared_ptr( - real_buffer, reinterpret_cast(real_buffer->data())); - - return absl::OkStatus(); -} - -absl::Status Tensor::DumpToBuffer(void* buffer) { - memcpy(buffer, Data(), num_elements * ElementSize()); - return absl::OkStatus(); -} - -absl::Status Tensor::DumpToVec(std::vector& out_data, bool exact_match) { - if (exact_match) { - RET_CHECK_EQ(num_elements * ElementSize(), out_data.size() * sizeof(float)); - } else { - out_data.resize(num_elements); - } - memcpy(out_data.data(), Data(), num_elements * ElementSize()); - return absl::OkStatus(); -} - -absl::Status Tensor::DumpToFile(absl::string_view file_path) { - return file::SetContents( - file_path, - absl::string_view(flat_data.get(), num_elements * ElementSize()), - file::Defaults()); -} - -absl::Status Tensor::LoadFromFile(absl::string_view file_path, bool use_mmap, - bool exact_match) { - const size_t expected_size_in_bytes = - exact_match ? num_elements * ElementSize() : 0; - - ASSIGN_OR_RETURN(flat_data, LoadBufferFromFile(file_path, use_mmap, - expected_size_in_bytes)); - return absl::OkStatus(); -} - -std::shared_ptr Tensor::Transpose() { - DCHECK_EQ(dims.size(), 2); - DimsType out_dims{dims.rbegin(), dims.rend()}; - auto result = std::make_shared(std::move(out_dims), datatype); - result->AllocateBufferIfNeeded(); - xnn_status s; - const DimsType perm{1, 0}; - if (datatype == xnn_datatype_fp32) { - s = xnn_run_transpose_nd_x32(Data(), result->Data(), dims.size(), - dims.data(), perm.data(), - /*flags=*/0, /*threadpool=*/nullptr); - } else { - LOG(FATAL) << "Need update to support new type"; - } - DCHECK_EQ(s, xnn_status_success); - return (s == xnn_status_success) ? result : nullptr; -} - -absl::StatusOr> Tensor::ConvertToF32() { - auto result = std::make_shared(dims, xnn_datatype_fp32); - MP_RETURN_IF_ERROR(result->LoadFromBuffer(Data())); - return result; -} - -absl::Status QCTensor::LoadFromFile(absl::string_view quantized_weight_filename, - absl::string_view scale_filename, - bool use_mmap, bool exact_match) { - size_t scale_element_size = dims[dim_scale]; - - ASSIGN_OR_RETURN(flat_data, - LoadBufferFromFile(quantized_weight_filename, use_mmap, - exact_match ? num_elements : 0)); - ASSIGN_OR_RETURN(scale_data, - LoadBufferFromFile( - scale_filename, use_mmap, - exact_match ? scale_element_size * sizeof(float) : 0)); - return absl::OkStatus(); -} - -absl::Status QCTensor::DumpToFile(absl::string_view file_path) { - MP_RETURN_IF_ERROR(file::SetContents( - file_path, - absl::string_view(flat_data.get(), num_elements * ElementSize()), - file::Defaults())); - return file::SetContents( - absl::StrCat(file_path, kQuantizedScaleSuffix), - absl::string_view(reinterpret_cast(scale_data.get()), - dims[dim_scale] * sizeof(float)), - file::Defaults()); -} - -absl::Status QCTensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) { - RET_CHECK_EQ( - xnn_status_success, - xnn_define_channelwise_quantized_tensor_value( - &subgraph, datatype, scale_data.get(), dims.size(), dim_scale, - dims.data(), Data(), XNN_INVALID_VALUE_ID, flags, &tensor_id)) - << *this; - RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID); - return absl::OkStatus(); -} - -void QCTensor::AllocateBufferIfNeeded() { - Tensor::AllocateBufferIfNeeded(); - if (!scale_data) { - auto real_buffer = std::make_shared>(); - real_buffer->reserve(dims[dim_scale]); - scale_data = std::shared_ptr(real_buffer, real_buffer->data()); - } -} - -std::shared_ptr QCTensor::Transpose() { - DCHECK_EQ(dims.size(), 2); - size_t channel_size = dims[dim_scale]; - DimsType out_dims{dims.rbegin(), dims.rend()}; - auto result = std::make_shared(std::move(out_dims), 1 - dim_scale); - result->AllocateBufferIfNeeded(); - memcpy(result->scale_data.get(), scale_data.get(), - channel_size * sizeof(float)); - xnn_status s; - const DimsType perm{1, 0}; - if (datatype == xnn_datatype_qcint8) { - s = xnn_run_transpose_nd_x8(Data(), result->Data(), dims.size(), - dims.data(), perm.data(), - /*flags=*/0, /*threadpool=*/nullptr); - } else { - LOG(FATAL) << "Need update to support new type"; - } - DCHECK_EQ(s, xnn_status_success); - return (s == xnn_status_success) ? result : nullptr; -} - -absl::StatusOr> QCTensor::ConvertToF32() { - auto result = std::make_shared(dims, xnn_datatype_fp32); - // TODO: proper implement. - LOG(WARNING) << "This is fake impl"; - MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false)); - return result; -} - -std::shared_ptr QCTensor::View(DimsType as_dims, - size_t dim_scale_if_any) { - auto result = std::make_shared(as_dims, dim_scale_if_any); - DCHECK_LE(result->num_elements, num_elements); - result->flat_data = flat_data; - result->scale_data = scale_data; - return result; -} - -} // namespace xnn_utils -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h b/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h deleted file mode 100644 index 10324ff4f..000000000 --- a/mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h +++ /dev/null @@ -1,202 +0,0 @@ -#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_ -#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_ - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/cleanup/cleanup.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "file/base/helpers.h" -#include "file/base/options.h" -#include "mediapipe/framework/port/status_macros.h" -#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h" -#include "third_party/XNNPACK/include/xnnpack.h" -#include "util/gtl/stl_logging.h" - -namespace mediapipe { -namespace xnn_utils { - -static constexpr absl::string_view kQuantizedScaleSuffix{"_quantized_scale"}; -static constexpr absl::string_view kSparsityParamsSuffix{"_sparsity_params"}; - -struct Tensor { - using DimsType = std::vector; - - explicit Tensor(DimsType in_dims, xnn_datatype datatype_ = xnn_datatype_fp32) - : dims(std::move(in_dims)), - num_elements(dims.empty() ? 0 - : std::accumulate(std::begin(dims), - std::end(dims), size_t(1), - std::multiplies())), - datatype(datatype_) {} - Tensor(Tensor&& other) = default; - - Tensor& operator=(const Tensor& other) = delete; - Tensor& operator=(Tensor&& other) = default; - - virtual ~Tensor() = default; - - bool operator==(const Tensor& other) const; - - void SetMetadata(absl::string_view key, int value) { metadata[key] = value; } - - std::optional GetMetadata(absl::string_view key) const { - if (metadata.contains(key)) { - return metadata.at(key); - } - return std::nullopt; - } - - // Read weights from file. - template - static absl::StatusOr> FromFile( - absl::string_view file_path, DimsType dims, bool use_mmap = true) { - auto result = std::make_shared(std::move(dims), xnn_datatype_); - - MP_RETURN_IF_ERROR( - result->LoadFromFile(file_path, use_mmap, /*exact_match=*/true)); - - return result; - } - - virtual absl::Status DefineAsExternal(xnn_subgraph& subgraph, uint32_t flags); - absl::Status DefineAsInput(xnn_subgraph& subgraph); - absl::Status DefineAsOutput(xnn_subgraph& subgraph); - absl::Status DefineAsIntermediateTensor(xnn_subgraph& subgraph); - virtual absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags); - absl::Status DefineWeight(xnn_subgraph& subgraph); - absl::Status DefineRope(xnn_subgraph& subgraph); - - absl::Status LoadFromBuffer(const void* buffer); - absl::Status LoadFromVec(const std::vector& data, - bool exact_match = true); - absl::Status LoadFromVec(std::vector&& data, bool exact_match = true); - absl::Status LoadFromFile(absl::string_view file_path) { - return LoadFromFile(file_path, true, true); - } - virtual absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap, - bool exact_match); - - absl::Status DumpToBuffer(void* buffer); - absl::Status DumpToVec(std::vector& out_data, bool exact_match = true); - virtual absl::Status DumpToFile(absl::string_view file_path); - - // If ith offset is 0, view's ith dim equals to original ith dim, otherwise 1. - std::shared_ptr Slice(DimsType offset); - // Slice along the `index`th dimension, offset at this dimension. - std::shared_ptr Slice(size_t index, size_t offset); - - // Point the underline data to the borrowed tensor's data. - Tensor& Borrow(std::shared_ptr, size_t element_offset = 0); - std::shared_ptr View(); - virtual std::shared_ptr View(DimsType as_dims, - size_t dim_scale_if_any = 0); - - Tensor& MarkOutput() { - AllocateBufferIfNeeded(); - is_output_tensor = true; - return *this; - } - - virtual void* Data(); - const void* Data() const; - - template - T* DataAs() { - DCHECK_EQ(ElementSize(), sizeof(T)); - return static_cast(Data()); - } - template - const T* DataAs() const { - return static_cast(Data()); - } - - virtual std::shared_ptr Transpose(); - - virtual absl::StatusOr> ConvertToF32(); - - DimsType dims; - size_t num_elements = 0; - xnn_datatype datatype = xnn_datatype_invalid; - uint32_t tensor_id = XNN_INVALID_VALUE_ID; - - // shared_ptr to make TensorMetadata copyable. - std::shared_ptr flat_data; - - protected: - friend class XnnGraphBuilder; - friend class XnnGraph; - - // Actually allocate buffer unless necessary. - virtual void AllocateBufferIfNeeded(); - - virtual size_t ElementSize() const { return 4; } - - bool is_output_tensor = false; - - absl::flat_hash_map metadata; -}; - -std::ostream& operator<<(std::ostream& os, const Tensor& tensor); - -// Channelwise Quantized. -struct QCTensor : public Tensor { - explicit QCTensor(DimsType in_dims, size_t dim_scale_if_any) - : Tensor(std::move(in_dims)), dim_scale(dim_scale_if_any) { - datatype = xnn_datatype_qcint8; - CHECK_LT(dim_scale, 4); - } - - void AllocateBufferIfNeeded() override; - size_t ElementSize() const override { return 1; } - - virtual absl::Status LoadFromFile(absl::string_view quantized_weight_filename, - absl::string_view scale_filename, - bool use_mmap, bool exact_match); - // Append kQuantizedScaleSuffix to use as scale filename. - absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap, - bool exact_match) override { - return LoadFromFile(file_path, - absl::StrCat(file_path, kQuantizedScaleSuffix), - use_mmap, exact_match); - } - - absl::Status DumpToFile(absl::string_view file_path) override; - - absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags) override; - - std::shared_ptr Transpose() override; - - absl::StatusOr> ConvertToF32() override; - - std::shared_ptr View(DimsType as_dims, - size_t dim_scale_if_any) override; - - std::shared_ptr scale_data; - // Index of the dimension to scale. - size_t dim_scale; -}; - -std::ostream& operator<<(std::ostream& os, const QCTensor& tensor); - -absl::Status FillXnnRoPEWeights(Tensor& out_seg_pos); - -} // namespace xnn_utils -} // namespace mediapipe - -#endif // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_