Internal Change
PiperOrigin-RevId: 547299595
This commit is contained in:
parent 56bc019819
commit aabf61f28d
			@ -1 +0,0 @@
# Utilities needed to interact with XNNPACK.
			@ -1,887 +0,0 @@
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/source_location.h"
#include "file/base/helpers.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
#include "third_party/XNNPACK/include/xnnpack.h"
#include "util/gtl/stl_logging.h"

namespace mediapipe {
namespace xnn_utils {
namespace {

// XNNPACK supports broadcasting; this function infers the output shape
// based on the input tensor shapes.
std::vector<size_t> OutDimsForElementwiseOp(const Tensor& lhs,
                                            const Tensor& rhs) {
  DCHECK(!lhs.dims.empty());
  DCHECK(!rhs.dims.empty());
  std::vector<size_t> lhs_dims_rev(lhs.dims.rbegin(), lhs.dims.rend());
  std::vector<size_t> rhs_dims_rev(rhs.dims.rbegin(), rhs.dims.rend());
  DCHECK([&]() -> bool {
    for (size_t i = 0; i < std::min(lhs_dims_rev.size(), rhs_dims_rev.size());
         ++i) {
      if ((lhs_dims_rev[i] != rhs_dims_rev[i]) && (lhs_dims_rev[i] != 1) &&
          (rhs_dims_rev[i] != 1)) {
        return false;
      }
    }
    return true;
  }()) << "lhs "
       << lhs.dims << " rhs " << rhs.dims;
  std::vector<size_t> out_dims(
      std::max(lhs_dims_rev.size(), rhs_dims_rev.size()));
  for (size_t i = 0; i < out_dims.size(); ++i) {
    if (lhs_dims_rev.size() <= i) {
      out_dims[i] = rhs_dims_rev[i];
    } else if (rhs_dims_rev.size() <= i) {
      out_dims[i] = lhs_dims_rev[i];
    } else {
      out_dims[i] = lhs_dims_rev[i] == 1 ? rhs_dims_rev[i] : lhs_dims_rev[i];
    }
  }
  return std::vector<size_t>(out_dims.rbegin(), out_dims.rend());
}

// If out_id is invalid, allocate a tensor for the intermediate result.
// Otherwise, record out_id in out_metadata.
absl::Status MaybeAllocateIntermediateTensor(xnn_subgraph_t subgraph,
                                             uint32_t out_id,
                                             Tensor& out_metadata) {
  RET_CHECK_GT(out_metadata.dims.size(), 0);
  if (out_id == XNN_INVALID_VALUE_ID) {
    // The output is intermediate, thus allocate a tensor.
    MP_RETURN_IF_ERROR(out_metadata.DefineAsIntermediateTensor(*subgraph));
  } else {
    out_metadata.tensor_id = out_id;
  }

  return absl::OkStatus();
}

absl::Status MaybeAllocateIntermediateTensor(xnn_subgraph_t subgraph,
                                             Tensor& out_metadata) {
  return MaybeAllocateIntermediateTensor(subgraph, out_metadata.tensor_id,
                                         out_metadata);
}

absl::Status AllocateIntermediateTensor(xnn_subgraph_t subgraph,
                                        Tensor& out_metadata) {
  return MaybeAllocateIntermediateTensor(subgraph, XNN_INVALID_VALUE_ID,
                                         out_metadata);
}

// 1.0/jax.nn.softplus(0.0) = 1.442695041
// scale = softplus(w) * 1.442695041 / np.sqrt(query.shape[-1])
void SoftPlus(size_t cnt, const std::vector<size_t>& query_dims, float* weight,
              float* scale) {
  constexpr double r_softplus_0 = 1.442695041;
  // softplus(x) = np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)
  // scale = softplus(per_dim_scale) / (sqrt(input.dims[-1]) * softplus(0))
  const double r_softplus_0_over_sqrt_d =
      r_softplus_0 / std::sqrt(query_dims.back());
  for (size_t i = 0; i < cnt; ++i) {
    scale[i] = std::log1p(std::exp(-std::abs(weight[i]))) +
               std::fmax(weight[i], 0.0f);
    scale[i] *= r_softplus_0_over_sqrt_d;
  }
}

}  // namespace

absl::StatusOr<std::unique_ptr<XnnGraph>> XnnGraphBuilder::Build(
    std::unique_ptr<RuntimeConfigs> runtime_configs) {
  if (!runtime_configs) {
    runtime_configs = std::make_unique<RuntimeConfigs>();
    runtime_configs->xnn_num_threads = 1;
    runtime_configs->xnn_profile = false;
  }
  VLOG(2) << "XnnGraphBuilder::Build() building...";
  auto build_begin = absl::Now();
  RET_CHECK_EQ(xnn_status_success, xnn_initialize(nullptr));

  absl::flat_hash_set<std::shared_ptr<Tensor>> output_tensors;
  {
    uint32_t cnt = input_tensors_.size();
    for (auto& t : interm_tensors_) {
      if (t->is_output_tensor) {
        RET_CHECK_EQ(t->tensor_id, XNN_INVALID_VALUE_ID);
        t->tensor_id = cnt++;
        output_tensors.insert(t);
      }
    }
    for (auto& t : output_tensors) {
      interm_tensors_.erase(t);
    }
    for (auto& t : rope_weights_) {
      interm_tensors_.erase(t);
      t->tensor_id = cnt++;
    }
  }

  xnn_subgraph_t subgraph_ptr = nullptr;
  RET_CHECK_EQ(xnn_status_success,
               xnn_create_subgraph(
                   /*external_value_ids=*/input_tensors_.size() +
                       output_tensors.size() + rope_weights_.size(),
                   /*flags=*/0, &subgraph_ptr));
  RET_CHECK_NE(subgraph_ptr, nullptr);

  XnnSubgraphPtr subgraph{subgraph_ptr, xnn_delete_subgraph};

  for (auto& input : input_tensors_) {
    MP_RETURN_IF_ERROR(input->DefineAsInput(*subgraph));
  }
  for (auto& output : output_tensors) {
    MP_RETURN_IF_ERROR(output->DefineAsOutput(*subgraph));
  }
  {
    for (auto& t : rope_weights_) {
      MP_RETURN_IF_ERROR(t->DefineRope(*subgraph));
    }
  }

  for (auto& [loc, step] : build_steps_) {
    if (auto s = step(subgraph.get()); !s.ok()) {
      s.AddSourceLocation(loc);
      return s;
    }
  }

  XnnGraph result(std::move(subgraph), std::move(runtime_configs));
  result.input_tensors_ = std::move(input_tensors_);
  result.output_tensors_ = std::move(output_tensors);
  result.interm_tensors_ = std::move(interm_tensors_);

  VLOG(2) << "XnnGraphBuilder::Build() creating runtime...";
  auto create_begin = absl::Now();
  MP_RETURN_IF_ERROR(result.CreateRuntime());
  VLOG(2) << "XnnGraphBuilder::Build() setting up runtime...";
  auto setup_begin = absl::Now();
  MP_RETURN_IF_ERROR(result.SetupRuntime());

  auto end = absl::Now();
  VLOG(2) << "XnnGraphBuilder::Build() done building, total "
          << end - build_begin << ", create runtime "
          << setup_begin - create_begin << ", setup runtime "
          << end - setup_begin;
  return std::make_unique<XnnGraph>(std::move(result));
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::NewInput(
    Tensor::DimsType dims, absl::SourceLocation loc) {
  auto t = std::make_shared<Tensor>(std::move(dims), data_type_);
  t->AllocateBufferIfNeeded();
  t->tensor_id = input_tensors_.size();
  input_tensors_.insert(t);
  return t;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::NewWeight(
    absl::string_view file_path, Tensor::DimsType dims,
    absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto t, NewWeight(std::move(dims)));
  MP_RETURN_IF_ERROR(t->LoadFromFile(file_path));
  return t;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::NewWeight(
    Tensor::DimsType dims, absl::SourceLocation loc) {
  auto t = std::make_shared<Tensor>(std::move(dims), data_type_);
  NewWeight(t, loc);
  return t;
}

void XnnGraphBuilder::NewWeight(std::shared_ptr<Tensor> t,
                                absl::SourceLocation loc) {
  build_steps_.push_back(
      {loc, [this, t](xnn_subgraph_t subgraph) -> absl::Status {
         if (interm_tensors_.contains(t)) {
           MP_RETURN_IF_ERROR(t->DefineWeight(*subgraph));
         }
         return absl::OkStatus();
       }});

  interm_tensors_.insert(t);
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::IntermediateTensor(
    Tensor::DimsType dims, absl::SourceLocation loc) {
  auto t = std::make_shared<Tensor>(std::move(dims), data_type_);

  build_steps_.push_back(
      {loc, [this, t](xnn_subgraph_t subgraph) -> absl::Status {
         // The tensor may have been moved to the output set, so check before
         // allocating.
         if (interm_tensors_.contains(t)) {
           return AllocateIntermediateTensor(subgraph, *t);
         }
         return absl::OkStatus();
       }});

  interm_tensors_.insert(t);
  return t;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Reshape(
    std::shared_ptr<Tensor> input, Tensor::DimsType new_dims,
    absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(new_dims)));
  RET_CHECK_EQ(input->num_elements, output->num_elements)
      << "otherwise reshape does not make sense.";

  build_steps_.push_back(
      {loc, [this, input, output](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
             subgraph, output->tensor_id, *output));

         RET_CHECK_EQ(xnn_status_success,
                      xnn_define_static_reshape(
                          subgraph, output->dims.size(), output->dims.data(),
                          input->tensor_id, output->tensor_id, /*flags=*/0));
         return absl::OkStatus();
       }});
  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::FullConn(
    std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
    std::shared_ptr<Tensor> bias, FullConnParams params,
    absl::SourceLocation loc) {
  const auto& input_dim = input->dims;
  const auto& weight_dim = weight->dims;
  DCHECK_GT(input_dim.size(), 1);
  DCHECK_GE(weight_dim.size(), 2);
  if (weight_dim.size() == 3) {
    RET_CHECK_EQ(weight_dim[0], 1);
  } else if (weight_dim.size() == 4) {
    RET_CHECK_EQ(weight_dim[0], 1);
    RET_CHECK_EQ(weight_dim[1], 1);
  }
  if (bias) {
    RET_CHECK_LE(bias->dims.size(), 1);
  }

  Tensor::DimsType out_dims = input_dim;
  // Not considering the reshape-to-2D case.
  if (params.transpose) {
    RET_CHECK_EQ(weight_dim.size(), 2) << "otherwise change following line";
    RET_CHECK_EQ(input_dim.back(), *(weight_dim.end() - 2));
    out_dims.back() = weight_dim.back();
  } else {
    RET_CHECK_EQ(input_dim.back(), weight_dim.back());
    out_dims.pop_back();
    for (size_t i = 0; i < weight_dim.size() - 1; ++i) {
      // NHD . BTD -> NHBT
      out_dims.push_back(weight_dim[i]);
    }
  }
  ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(out_dims)));

  build_steps_.push_back(
      {loc,
       [this, input, weight, bias, params,
        output](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
             subgraph, output->tensor_id, *output));

         RET_CHECK_EQ(
             xnn_status_success,
             xnn_define_fully_connected(
                 subgraph, params.out_min, params.out_max, input->tensor_id,
                 weight->tensor_id,
                 bias ? bias->tensor_id : XNN_INVALID_VALUE_ID,
                 output->tensor_id,
                 /*flags=*/params.transpose ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0));

         return absl::OkStatus();
       }});
  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Permute(
    std::shared_ptr<Tensor> input, Tensor::DimsType permute,
    absl::SourceLocation loc) {
  RET_CHECK_EQ(input->dims.size(), permute.size());
  const auto& old_dims = input->dims;
  std::vector<size_t> new_dims;
  for (size_t i = 0; i < permute.size(); ++i) {
    new_dims.push_back(old_dims[permute[i]]);
  }
  ASSIGN_OR_RETURN(auto output, IntermediateTensor(std::move(new_dims)));

  build_steps_.push_back(
      {loc,
       [this, permute, input, output](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));

         RET_CHECK_EQ(xnn_status_success,
                      xnn_define_static_transpose(
                          subgraph, permute.size(), permute.data(),
                          input->tensor_id, output->tensor_id, /*flags=*/0));
         return absl::OkStatus();
       }});
  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Square(
    std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));

  build_steps_.push_back(
      {loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
             subgraph, output->tensor_id, *output));
         RET_CHECK_EQ(
             xnn_status_success,
             xnn_define_square(subgraph, input->tensor_id, output->tensor_id,
                               /*flags=*/0));
         return absl::OkStatus();
       }});

  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Softmax(
    std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));

  build_steps_.push_back(
      {loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
             subgraph, output->tensor_id, *output));
         RET_CHECK_EQ(
             xnn_status_success,
             xnn_define_softmax(subgraph, input->tensor_id, output->tensor_id,
                                /*flags=*/0));
         return absl::OkStatus();
       }});

  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::SquareRoot(
    std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));

  build_steps_.push_back(
      {loc, [this, output, input](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
             subgraph, output->tensor_id, *output));
         RET_CHECK_EQ(xnn_status_success,
                      xnn_define_square_root(subgraph, input->tensor_id,
                                             output->tensor_id,
                                             /*flags=*/0));
         return absl::OkStatus();
       }});

  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::AvgLastDim(
    std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto before_reshape,
                   IntermediateTensor(Tensor::DimsType{input->dims.begin(),
                                                       input->dims.end() - 1}));
  build_steps_.push_back(
      {loc,
       [this, input, before_reshape](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(
             subgraph, before_reshape->tensor_id, *before_reshape));
         size_t reduction_axis = input->dims.size() - 1;
         RET_CHECK_EQ(
             xnn_status_success,
             xnn_define_static_mean(subgraph, 1, &reduction_axis,
                                    input->tensor_id, before_reshape->tensor_id,
                                    /*flags=*/0));
         return absl::OkStatus();
       }});

  Tensor::DimsType new_dims = input->dims;
  new_dims.back() = 1;
  return Reshape(before_reshape, std::move(new_dims));
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Rms(
    std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto sqr_out, Square(input, loc));

  ASSIGN_OR_RETURN(auto mean_out, AvgLastDim(sqr_out, loc));

  return SquareRoot(mean_out, loc);
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::RmsNorm(
    std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> scale,
    absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto rms_out, Rms(input));

  ASSIGN_OR_RETURN(auto clamped_rms, Clamp(rms_out, {.out_min = 1e-6}));

  // div_out = input / rms
  ASSIGN_OR_RETURN(auto div_out, ElementDiv(input, clamped_rms));

  // div_out * (1 + scale) = div_out + div_out * scale
  ASSIGN_OR_RETURN(auto normed_div_out, ElementMul(div_out, scale));

  return ElementAdd(div_out, normed_div_out);
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementAdd(
    std::shared_ptr<Tensor> lhs, float rhs, ClampParams params,
    absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1}));
  MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector<float>({rhs})));

  return ElementAdd(lhs, rhs_tensor, params, loc);
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementAdd(
    std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
    ClampParams params, absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto output,
                   IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs)));

  build_steps_.push_back(
      {loc,
       [this, lhs, rhs, output,
        params](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
         RET_CHECK_EQ(xnn_status_success,
                      xnn_define_add2(subgraph, params.out_min, params.out_max,
                                      lhs->tensor_id, rhs->tensor_id,
                                      output->tensor_id, /*flags=*/0));
         return absl::OkStatus();
       }});

  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementMul(
    std::shared_ptr<Tensor> lhs, float rhs, ClampParams params,
    absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1}));
  MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector<float>({rhs})));

  return ElementMul(lhs, rhs_tensor, params, loc);
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementMul(
    std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
    ClampParams params, absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto output,
                   IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs)));

  build_steps_.push_back(
      {loc,
       [this, lhs, rhs, output,
        params](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
         RET_CHECK_EQ(
             xnn_status_success,
             xnn_define_multiply2(subgraph, params.out_min, params.out_max,
                                  lhs->tensor_id, rhs->tensor_id,
                                  output->tensor_id, /*flags=*/0));
         return absl::OkStatus();
       }});

  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementDiv(
    std::shared_ptr<Tensor> lhs, float rhs, ClampParams params,
    absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto rhs_tensor, NewWeight({1}));
  MP_RETURN_IF_ERROR(rhs_tensor->LoadFromVec(std::vector<float>({rhs})));

  return ElementDiv(lhs, rhs_tensor, params, loc);
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::ElementDiv(
    std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
    ClampParams params, absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto output,
                   IntermediateTensor(OutDimsForElementwiseOp(*lhs, *rhs)));

  build_steps_.push_back(
      {loc,
       [this, lhs, rhs, output,
        params](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
         RET_CHECK_EQ(
             xnn_status_success,
             xnn_define_divide(subgraph, params.out_min, params.out_max,
                               lhs->tensor_id, rhs->tensor_id,
                               output->tensor_id, /*flags=*/0));
         return absl::OkStatus();
       }});

  return output;
}

// TODO: write an op?
absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::PerDimScale(
    std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> per_dim_scale,
    absl::SourceLocation loc) {
  // input: B T N H
  // 1/softplus(0) = 1.442695041
  // scale = softplus(w) * 1.442695041 / np.sqrt(query.shape[-1])
  // query = query * scale
  const auto& input_dim = input->dims;
  DCHECK_GE(input_dim.size(), 1);
  const size_t H = input_dim.back();

  if (!per_dim_scale_cache_.contains(H) ||
      !per_dim_scale_cache_[H].contains(per_dim_scale.get())) {
    ASSIGN_OR_RETURN(auto cached_pds, NewWeight(per_dim_scale->dims));

    auto* pds_in = static_cast<float*>(per_dim_scale->Data());
    std::vector<float> pds_scaled(per_dim_scale->num_elements);
    SoftPlus(per_dim_scale->num_elements, input_dim, pds_in, pds_scaled.data());
    MP_RETURN_IF_ERROR(cached_pds->LoadFromVec(std::move(pds_scaled)));
    per_dim_scale_cache_[H][per_dim_scale.get()] = cached_pds;
  }

  return ElementMul(input, per_dim_scale_cache_[H][per_dim_scale.get()]);
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Rope(
    std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> segment_pos,
    absl::SourceLocation loc) {
  // TODO: seg_pos should not be a weight.
  rope_weights_.insert(segment_pos);

  const auto& input_dim = input->dims;
  const auto& segment_pos_dim = segment_pos->dims;
  // B T N H
  RET_CHECK_EQ(input_dim.size(), 4) << "xnn requirement";
  // S H
  RET_CHECK_EQ(segment_pos_dim.size(), 2) << "xnn requirement";

  ASSIGN_OR_RETURN(auto output, IntermediateTensor(input_dim));

  const auto input_seq_size = input_dim[1];
  RET_CHECK_LE(input_seq_size, segment_pos_dim[0]);
  const auto head_dim_H = input_dim[3];
  RET_CHECK_EQ(head_dim_H, segment_pos_dim[1]);

  build_steps_.push_back(
      {loc,
       [this, input, output, segment_pos,
        input_seq_size](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));
         RET_CHECK_EQ(
             xnn_status_success,
             xnn_define_rope(subgraph, input_seq_size, input->tensor_id,
                             segment_pos->tensor_id, output->tensor_id,
                             /*flags=*/0));
         return absl::OkStatus();
       }});

  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::BatchMatMul(
    std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
    FullConnParams params, absl::SourceLocation loc) {
  const auto& lhs_dim = input->dims;
  const auto& rhs_dim = weight->dims;

  // [B, N, T, H] . [B, N, S, H], N == 12, B == 1
  DCHECK_EQ(lhs_dim.size(), 4);
  DCHECK_EQ(rhs_dim.size(), 4);
  DCHECK_EQ(lhs_dim.back(), rhs_dim.back());
  constexpr size_t num_slices = 12;
  DCHECK_EQ(lhs_dim[1], num_slices);
  DCHECK_EQ(rhs_dim[1], num_slices);
  const size_t S = rhs_dim[2];
  const size_t T = lhs_dim[2];
  const size_t batch_size = lhs_dim[0] * lhs_dim[1];
  DCHECK_EQ(batch_size, rhs_dim[0] * rhs_dim[1]);
  DCHECK_EQ(batch_size, 12);

  ASSIGN_OR_RETURN(auto output, IntermediateTensor({1, 12, T, S}));

  build_steps_.push_back(
      {loc, [input, output, weight](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));

         RET_CHECK_EQ(xnn_status_success,
                      xnn_define_batch_matrix_multiply(
                          subgraph, input->tensor_id, weight->tensor_id,
                          output->tensor_id, /*flags=*/0));

         return absl::OkStatus();
       }});

  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Tanh(
    std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));

  build_steps_.push_back(
      {loc, [this, input, output](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));

         RET_CHECK_EQ(xnn_status_success,
                      xnn_define_tanh(subgraph, input->tensor_id,
                                      output->tensor_id, /*flags=*/0));
         return absl::OkStatus();
       }});

  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::CapTanh(
    std::shared_ptr<Tensor> input, float cap, absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto div, ElementDiv(input, cap));
  ASSIGN_OR_RETURN(auto tanh, Tanh(div));
  return ElementMul(tanh, cap);
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::DotAttention(
    std::shared_ptr<Tensor> query_proj, std::shared_ptr<Tensor> key_proj,
    std::shared_ptr<Tensor> value_proj, std::shared_ptr<Tensor> atten_mask,
    std::shared_ptr<Tensor> per_dim_scale, absl::SourceLocation loc) {
  // BTNH
  ASSIGN_OR_RETURN(auto query_after_scale,
                   PerDimScale(query_proj, per_dim_scale));

  // Dot similarity
  // BTNH -> BNTH
  ASSIGN_OR_RETURN(auto query_permuted,
                   Permute(query_after_scale, {0, 2, 1, 3}));
  // BSNH -> BNSH
  ASSIGN_OR_RETURN(auto key_permuted, Permute(key_proj, {0, 2, 1, 3}));
  // einsum(BNTH.BNSH -> BNTS)
  ASSIGN_OR_RETURN(auto logits, BatchMatMul(query_permuted, key_permuted));

  // Cap, mask
  ASSIGN_OR_RETURN(auto cap_logits, CapTanh(logits, 50));
  ASSIGN_OR_RETURN(auto padded_logits, ElementAdd(atten_mask, cap_logits));
  ASSIGN_OR_RETURN(auto probs, Softmax(padded_logits));
  ASSIGN_OR_RETURN(auto value_permuted, Permute(value_proj, {0, 2, 3, 1}));

  // Outcome
  // BNTS.BNHS -> BNTH
  ASSIGN_OR_RETURN(auto outcome_before_permute,
                   BatchMatMul(probs, value_permuted));
  // [B, N, T, H] -> BTNH
  return Permute(outcome_before_permute, {0, 2, 1, 3});
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::SelfAttentionProj(
    std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
    absl::SourceLocation loc) {
  const auto& input_dim = input->dims;
  const auto& weight_dim = weight->dims;
  size_t N = 0, H = 0;
  RET_CHECK_EQ(input_dim.size(), 3) << "BTD";

  std::optional<int> reshaped_N =
      weight->GetMetadata(kKeySelfAttentionReshapedWeight);
  RET_CHECK(reshaped_N && *reshaped_N)
      << "We rely on " << kKeySelfAttentionReshapedWeight << " to get N";
  RET_CHECK_EQ(weight_dim.size(), 2) << "NH,D";
  N = *reshaped_N;
  H = weight_dim[0] / N;

  // out: B,T,NH
  ASSIGN_OR_RETURN(auto proj, MatMul(input, weight));

  // B,T,NH -> B,T,N,H
  return Reshape(proj, {input_dim[0], input_dim[1], N, H});
}

absl::Status XnnGraph::CreateRuntime() {
  RET_CHECK_EQ(runtime_.get(), nullptr);
  xnn_runtime_t runtime_ptr = nullptr;
  uint32_t flags = 0;
  if (runtime_configs_->xnn_profile) {
    flags |= XNN_FLAG_BASIC_PROFILING;

    if (!runtime_configs_->xnn_profile_csv.empty()) {
      MP_RETURN_IF_ERROR(file::SetContents(runtime_configs_->xnn_profile_csv,
                                           "node_id; time(us); op_name\n",
                                           file::Defaults()));
    }
  }
  pthreadpool_t threadpool =
      pthreadpool_create(runtime_configs_->xnn_num_threads);
  threadpool_ = XnnThreadpoolPtr{threadpool, pthreadpool_destroy};

  RET_CHECK_EQ(xnn_status_success,
               xnn_create_runtime_v2(owned_subgraph_.get(), threadpool, flags,
                                     &runtime_ptr));
  RET_CHECK_NE(runtime_ptr, nullptr);
  runtime_ = XnnRuntimePtr{runtime_ptr, xnn_delete_runtime};

  return absl::OkStatus();
}

absl::Status XnnGraph::SetupRuntime() {
  {
    VLOG(3) << "input size " << input_tensors_.size();
    VLOG(3) << "output size " << output_tensors_.size();
    VLOG(3) << "rope size " << rope_weights_.size();
    externals_.clear();
    // Initialize externals.
    for (const auto& input : input_tensors_) {
      VLOG(3) << "input id " << input->tensor_id;
      externals_.push_back(xnn_external_value{input->tensor_id, input->Data()});
    }
    for (const auto& output : output_tensors_) {
      VLOG(3) << "output id " << output->tensor_id;
      externals_.push_back(
          xnn_external_value{output->tensor_id, output->Data()});
    }
    for (const auto& t : rope_weights_) {
      VLOG(3) << "rope id " << t->tensor_id;
    }
  }
  RET_CHECK_EQ(
      xnn_status_success,
      xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data()));
  return absl::OkStatus();
}

absl::Status XnnGraph::Run() {
  RET_CHECK(runtime_);

  RET_CHECK_EQ(xnn_status_success, xnn_invoke_runtime(runtime_.get()));

  if (runtime_configs_->xnn_profile) {
    size_t required_size = 0;

    // xnn_get_runtime_profiling_info is called twice. The first time it sets
    // required_size to the required size of the buffer to store the result
    // and returns xnn_status_out_of_memory. The second time it writes the
    // result to the buffer, provided that the buffer is large enough, and
    // returns xnn_status_success.
    xnn_status status = xnn_get_runtime_profiling_info(
        runtime_.get(), xnn_profile_info_operator_name, /*param_value_size*/ 0,
        /*param_value*/ nullptr, &required_size);
    std::vector<char> operator_names;
    if (status == xnn_status_out_of_memory) {
      operator_names.resize(required_size);
      status = xnn_get_runtime_profiling_info(
          runtime_.get(), xnn_profile_info_operator_name, operator_names.size(),
          operator_names.data(), &required_size);
    }
    RET_CHECK_EQ(status, xnn_status_success);
    size_t num_operators = 0;
    status = xnn_get_runtime_profiling_info(
        runtime_.get(), xnn_profile_info_num_operators, sizeof(num_operators),
        &num_operators, &required_size);
    RET_CHECK_EQ(status, xnn_status_success);
    status = xnn_get_runtime_profiling_info(
        runtime_.get(), xnn_profile_info_operator_timing,
        /*param_value_size*/ 0,
        /*param_value*/ nullptr, &required_size);
    std::vector<uint64_t> operator_timings;
    if (status == xnn_status_out_of_memory) {
      operator_timings.resize(required_size / sizeof(uint64_t));
      status = xnn_get_runtime_profiling_info(
          runtime_.get(), xnn_profile_info_operator_timing,
          operator_timings.size() * sizeof(uint64_t), operator_timings.data(),
          &required_size);
    }
    RET_CHECK_EQ(status, xnn_status_success);
    const char* operator_name = nullptr;
    size_t name_len = 0;
    std::stringstream ss;
    for (size_t node_index = 0; node_index < num_operators; ++node_index) {
      operator_name = &operator_names[name_len];
      name_len += strlen(operator_name) + 1;
      VLOG(2) << "XnnGraphBuilder::Profile() node_index: " << node_index
              << ", time: " << operator_timings[node_index] << " us, "
              << operator_name << "\n";
      if (!runtime_configs_->xnn_profile_csv.empty()) {
        // Use ';' instead of ',' because operator_name contains commas.
        ss << node_index << "; " << operator_timings[node_index] << "; "
           << operator_name << "\n";
      }
    }
    if (!runtime_configs_->xnn_profile_csv.empty()) {
      MP_RETURN_IF_ERROR(file::AppendStringToFile(
          runtime_configs_->xnn_profile_csv, ss.str(), file::Defaults()));
    }
  }

  return absl::OkStatus();
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Clamp(
    std::shared_ptr<Tensor> input, ClampParams params,
    absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto output, IntermediateTensor(input->dims));

  build_steps_.push_back(
      {loc,
       [this, input, output, params](xnn_subgraph_t subgraph) -> absl::Status {
         MP_RETURN_IF_ERROR(MaybeAllocateIntermediateTensor(subgraph, *output));

         RET_CHECK_EQ(xnn_status_success,
                      xnn_define_clamp(subgraph, params.out_min, params.out_max,
                                       input->tensor_id, output->tensor_id,
                                       /*flags=*/0));
         return absl::OkStatus();
       }});

  return output;
}

absl::StatusOr<std::shared_ptr<Tensor>> XnnGraphBuilder::Gelu(
    std::shared_ptr<Tensor> input, absl::SourceLocation loc) {
  // x^2
  ASSIGN_OR_RETURN(auto sqr_out, Square(input));

  // 0.044715 * x^2
  ASSIGN_OR_RETURN(auto sqr_4471, ElementMul(sqr_out, 0.044715));

  // 1 + 0.044715 * x^2
  ASSIGN_OR_RETURN(auto sqr_4471_1, ElementAdd(sqr_4471, 1.0f));

  // x + 0.044715 * x^3
  ASSIGN_OR_RETURN(auto x_cube_4471, ElementMul(sqr_4471_1, input));

  constexpr float sqrt_2_over_pi = 0.7978845608;
  ASSIGN_OR_RETURN(auto sqrt_2_over_pi_x_cube_4471,
                   ElementMul(x_cube_4471, sqrt_2_over_pi));

  // tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
  ASSIGN_OR_RETURN(auto tanh_x_cube_4471, Tanh(sqrt_2_over_pi_x_cube_4471));

  // 1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
  ASSIGN_OR_RETURN(auto tanh_x_cube_4471_1, ElementAdd(tanh_x_cube_4471, 1.0f));

  // cdf = 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  ASSIGN_OR_RETURN(auto cdf, ElementMul(tanh_x_cube_4471_1, 0.5));

  return ElementMul(input, cdf);
}

}  // namespace xnn_utils
}  // namespace mediapipe
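An aside on the hard-coded constant in SoftPlus() above: r_softplus_0 is 1/softplus(0). Since softplus(0) = log(1 + e^0) = log 2 ≈ 0.693147, the reciprocal is 1/log 2 ≈ 1.442695041. A standalone check (my own snippet, not part of the original file):

#include <cmath>
#include <cstdio>

int main() {
  // softplus(0) = log(1 + exp(0)) = log(2).
  const double softplus_0 = std::log1p(std::exp(0.0));
  // Prints 1.442695041, matching the r_softplus_0 constant above.
  std::printf("1/softplus(0) = %.9f\n", 1.0 / softplus_0);
  return 0;
}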
			@ -1,288 +0,0 @@
 | 
			
		|||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_
 | 
			
		||||
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_
 | 
			
		||||
 | 
			
		||||
#include <sys/types.h>
 | 
			
		||||
 | 
			
		||||
#include <array>
 | 
			
		||||
#include <cstddef>
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <functional>
 | 
			
		||||
#include <initializer_list>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <numeric>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <type_traits>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/container/flat_hash_map.h"
 | 
			
		||||
#include "absl/container/flat_hash_set.h"
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "absl/types/source_location.h"
 | 
			
		||||
#include "file/base/helpers.h"
 | 
			
		||||
#include "mediapipe/framework/port/ret_check.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_macros.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
 | 
			
		||||
#include "third_party/XNNPACK/include/xnnpack.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace xnn_utils {
 | 
			
		||||
 | 
			
		||||
using XnnSubgraphPtr =
 | 
			
		||||
    std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)>;
 | 
			
		||||
using XnnRuntimePtr =
 | 
			
		||||
    std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)>;
 | 
			
		||||
using XnnThreadpoolPtr =
 | 
			
		||||
    std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)>;
 | 
			
		||||
 | 
			
		||||
struct ClampParams {
 | 
			
		||||
  float out_min = -std::numeric_limits<float>::infinity();
 | 
			
		||||
  float out_max = std::numeric_limits<float>::infinity();
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct FullConnParams : public ClampParams {
 | 
			
		||||
  bool transpose = false;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct RuntimeConfigs {
 | 
			
		||||
  bool xnn_profile;
 | 
			
		||||
  std::string xnn_profile_csv;
 | 
			
		||||
  size_t xnn_num_threads;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class XnnGraph;
 | 
			
		||||
 | 
			
		||||
// XnnGraphBuilder is used to construct XnnGraph (through Build()). Once a
 | 
			
		||||
// XnnGraph is constructed, it can run for multiple times.
 | 
			
		||||
class XnnGraphBuilder {
 | 
			
		||||
 public:
 | 
			
		||||
  static constexpr absl::string_view kKeySelfAttentionReshapedWeight{
 | 
			
		||||
      "self_attention_reshaped_weight_N"};
 | 
			
		||||
 | 
			
		||||
  explicit XnnGraphBuilder(xnn_datatype data_type = xnn_datatype_fp32)
 | 
			
		||||
      : data_type_(data_type) {}
 | 
			
		||||
  virtual ~XnnGraphBuilder() = default;
 | 
			
		||||
 | 
			
		||||
  absl::StatusOr<std::unique_ptr<XnnGraph>> Build(
 | 
			
		||||
      std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
 | 
			
		||||
 | 
			
		||||
  // New input or output tensor.
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> NewInput(
 | 
			
		||||
      Tensor::DimsType dims,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  // New static weight, populate value before Build()
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> NewWeight(
 | 
			
		||||
      Tensor::DimsType dims,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> NewWeight(
 | 
			
		||||
      absl::string_view file_path, Tensor::DimsType dims,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
  void NewWeight(std::shared_ptr<Tensor> t,
 | 
			
		||||
                 absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  // Element wise square.
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> Square(
 | 
			
		||||
      std::shared_ptr<Tensor> input,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> SquareRoot(
 | 
			
		||||
      std::shared_ptr<Tensor> input,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> Gelu(
 | 
			
		||||
      std::shared_ptr<Tensor> input,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> Clamp(
 | 
			
		||||
      std::shared_ptr<Tensor> input, ClampParams params,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> Tanh(
 | 
			
		||||
      std::shared_ptr<Tensor> input,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  // logits = cap * jnp.tanh(logits / cap)
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> CapTanh(
 | 
			
		||||
      std::shared_ptr<Tensor> input, float cap,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  // Average over last dimension, keep num of dims same.
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> AvgLastDim(
 | 
			
		||||
      std::shared_ptr<Tensor> input,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> Rms(
 | 
			
		||||
      std::shared_ptr<Tensor> input,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> RmsNorm(
 | 
			
		||||
      std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> scale,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> Reshape(
 | 
			
		||||
      std::shared_ptr<Tensor> input, Tensor::DimsType new_dims,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> Permute(
 | 
			
		||||
      std::shared_ptr<Tensor> input, Tensor::DimsType permute,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current());
 | 
			
		||||
 | 
			
		||||
  // input: [B * I]
 | 
			
		||||
  // filter: [O * I], [I * O] if transpose
 | 
			
		||||
  // return: [B * O]
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> MatMul(
 | 
			
		||||
      std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
 | 
			
		||||
      absl::SourceLocation loc = absl::SourceLocation::current()) {
 | 
			
		||||
    return MatMul(input, weight, FullConnParams(), loc);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  absl::StatusOr<std::shared_ptr<Tensor>> MatMul(
 | 
			
      std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
      FullConnParams params,
      absl::SourceLocation loc = absl::SourceLocation::current()) {
    return FullConn(input, weight, nullptr, params, loc);
  }

  absl::StatusOr<std::shared_ptr<Tensor>> BatchMatMul(
      std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
      FullConnParams params = FullConnParams(),
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> FullConn(
      std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
      std::shared_ptr<Tensor> bias,
      absl::SourceLocation loc = absl::SourceLocation::current()) {
    return FullConn(input, weight, bias, FullConnParams(), loc);
  }

  absl::StatusOr<std::shared_ptr<Tensor>> FullConn(
      std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
      std::shared_ptr<Tensor> bias, FullConnParams params,
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> Softmax(
      std::shared_ptr<Tensor> input,
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionProj(
      std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> weight,
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> ElementAdd(
      std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
      ClampParams params = ClampParams(),
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> ElementAdd(
      std::shared_ptr<Tensor> lhs, float rhs,
      ClampParams params = ClampParams(),
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> ElementMul(
      std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
      ClampParams params = ClampParams(),
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> ElementMul(
      std::shared_ptr<Tensor> lhs, float rhs,
      ClampParams params = ClampParams(),
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> ElementDiv(
      std::shared_ptr<Tensor> lhs, std::shared_ptr<Tensor> rhs,
      ClampParams params = ClampParams(),
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> ElementDiv(
      std::shared_ptr<Tensor> lhs, float rhs,
      ClampParams params = ClampParams(),
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> Rope(
      std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> segment_pos,
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> PerDimScale(
      std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> per_dim_scale,
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> DotAttention(
      std::shared_ptr<Tensor> query_proj, std::shared_ptr<Tensor> key_proj,
      std::shared_ptr<Tensor> value_proj, std::shared_ptr<Tensor> atten_mask,
      std::shared_ptr<Tensor> per_dim_scale,
      absl::SourceLocation loc = absl::SourceLocation::current());
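
  // A minimal usage sketch (shapes are illustrative, and it assumes `builder`
  // is an already-constructed XnnGraphBuilder, e.g. a UlmBuilder, with a
  // RuntimeConfigs instance at hand):
  //
  //   ASSIGN_OR_RETURN(auto x, builder.NewInput({/*B=*/1, /*T=*/16, 512}));
  //   ASSIGN_OR_RETURN(auto w, builder.NewWeight({512, 512}));
  //   ASSIGN_OR_RETURN(auto y, builder.FullConn(x, w, /*bias=*/nullptr));
  //   ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs)));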

 protected:
  absl::StatusOr<std::shared_ptr<Tensor>> IntermediateTensor(
      Tensor::DimsType dims,
      absl::SourceLocation loc = absl::SourceLocation::current());

  const xnn_datatype data_type_;

  std::vector<std::pair<absl::SourceLocation,
                        std::function<absl::Status(xnn_subgraph_t)>>>
      build_steps_;

  absl::flat_hash_set<std::shared_ptr<Tensor>> input_tensors_;
  absl::flat_hash_set<std::shared_ptr<Tensor>> interm_tensors_;

  // TODO: fix this. It is a quirk of the current implementation that the
  // weights used for RoPE have to be defined with the EXTERNAL flag, yet with
  // an id outside the external range.
  absl::flat_hash_set<std::shared_ptr<Tensor>> rope_weigths_;

  // Caches
  absl::flat_hash_map<
      size_t /*dim*/,
      absl::flat_hash_map<const Tensor* /*scale*/, std::shared_ptr<Tensor>>>
      per_dim_scale_cache_;
};

class XnnGraph {
 public:
  XnnGraph(XnnSubgraphPtr subgraph,
           std::unique_ptr<RuntimeConfigs> runtime_configs)
      : owned_subgraph_(std::move(subgraph)),
        runtime_configs_(std::move(runtime_configs)) {
    DCHECK(runtime_configs_);
  }
  XnnGraph(XnnGraph&& other) = default;
  virtual ~XnnGraph() = default;

  // The underlying xnn_subgraph should be created with the same size.
  virtual absl::Status Run();

 protected:
  friend class XnnGraphBuilder;

  absl::Status CreateRuntime();
  absl::Status SetupRuntime();

  XnnSubgraphPtr owned_subgraph_;

  absl::flat_hash_map<size_t, Tensor> avg_cache_;
  absl::flat_hash_map<size_t, Tensor> cap_tanh_cache_;

  // Runtime
  std::unique_ptr<RuntimeConfigs> runtime_configs_;
  XnnRuntimePtr runtime_{nullptr, xnn_delete_runtime};
  std::vector<xnn_external_value> externals_;

  XnnThreadpoolPtr threadpool_{nullptr, pthreadpool_destroy};

  absl::flat_hash_set<std::shared_ptr<Tensor>> input_tensors_;
  absl::flat_hash_set<std::shared_ptr<Tensor>> output_tensors_;
  // TODO: see the note on rope_weigths_ above.
  absl::flat_hash_set<std::shared_ptr<Tensor>> rope_weigths_;

  absl::flat_hash_set<std::shared_ptr<Tensor>> interm_tensors_;
};

}  // namespace xnn_utils
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_GRAPH_BUILDER_H_

@ -1,475 +0,0 @@
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm.h"
 | 
			
		||||
 | 
			
		||||
#include <cstddef>
 | 
			
		||||
#include <limits>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/log/check.h"
 | 
			
		||||
#include "absl/log/log.h"
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "mediapipe/framework/port/ret_check.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_macros.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/text_generator/calculators/preprocessor_util.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/text_generator/calculators/sampler_util.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
 | 
			
		||||
#include "util/gtl/stl_logging.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace xnn_utils {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<std::shared_ptr<Tensor>> ApplyFinalProj(
 | 
			
		||||
    std::shared_ptr<Tensor> inter_layer, const UlmWeights& weights,
 | 
			
		||||
    XnnGraphBuilder& builder) {
 | 
			
		||||
  return builder.FullConn(inter_layer, weights.softmax_linear,
 | 
			
		||||
                          weights.softmax_bias);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
class OneTokenUlm : public Ulm {
 | 
			
		||||
 public:
 | 
			
		||||
  OneTokenUlm(std::unique_ptr<Ulm> full_ulm, XnnGraph&& other)
 | 
			
		||||
      : Ulm(std::move(other)), full_ulm_(std::move(full_ulm)) {}
 | 
			
		||||
  ~OneTokenUlm() override = default;
 | 
			
		||||
 | 
			
		||||
  absl::Status InitInputTokens(const std::vector<int>& input_ids) override {
 | 
			
		||||
    prev_ids_ = input_ids;
 | 
			
		||||
    MP_RETURN_IF_ERROR(full_ulm_->InitInputTokens(input_ids));
 | 
			
		||||
    // prev_id.size - 1 is the output.
 | 
			
		||||
    return full_ulm_->Run();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  absl::Status GetNextToken(std::vector<int>* output_ids) override {
 | 
			
		||||
    size_t decode_step = prev_ids_.size() - 1;
 | 
			
		||||
    VLOG(2) << "Decode step " << decode_step;
 | 
			
		||||
 | 
			
		||||
    if (decode_step == ulm_params_.seq_size_T - 1) {
 | 
			
		||||
      return absl::OutOfRangeError(
 | 
			
		||||
          absl::StrCat("Hit max sequence length ", ulm_params_.seq_size_T));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    transformer_input_->Borrow(
 | 
			
		||||
        full_ulm_->transformer_input_->Slice(1, decode_step));
 | 
			
		||||
    atten_masks_->Borrow(full_ulm_->atten_masks_->Slice(0, decode_step));
 | 
			
		||||
    MP_RETURN_IF_ERROR(segment_pos_->LoadFromBuffer(
 | 
			
		||||
        full_ulm_->segment_pos_->Slice(0, decode_step)->Data()));
 | 
			
		||||
    for (auto& kv_cache : kv_cache_) {
 | 
			
		||||
      DCHECK(kv_cache.k_slice);
 | 
			
		||||
      DCHECK(kv_cache.v_slice);
 | 
			
		||||
      kv_cache.k_slice->Borrow(kv_cache.k_cache->Slice(1, decode_step));
 | 
			
		||||
      kv_cache.v_slice->Borrow(kv_cache.v_cache->Slice(1, decode_step));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    MP_RETURN_IF_ERROR(SetupRuntime());
 | 
			
		||||
    MP_RETURN_IF_ERROR(Run());
 | 
			
		||||
 | 
			
		||||
    RET_CHECK(logits_output_);
 | 
			
		||||
    DCHECK_EQ(logits_output_->num_elements, ulm_params_.voc_size_V);
 | 
			
		||||
 | 
			
		||||
    ASSIGN_OR_RETURN(*output_ids,
 | 
			
		||||
                     mediapipe::SampleNextToken(
 | 
			
		||||
                         logits_output_->DataAs<float>(),
 | 
			
		||||
                         /*batch_size=*/1,
 | 
			
		||||
                         /*vocab_size=*/ulm_params_.voc_size_V, /*top_k=*/10,
 | 
			
		||||
                         /*top_p=*/1, /*temperature=*/-1));
 | 
			
		||||
    RET_CHECK_EQ(output_ids->size(), 1);
 | 
			
		||||
    prev_ids_.push_back(output_ids->at(0));
 | 
			
		||||
 | 
			
		||||
    return GetTokenEmbedding(
 | 
			
		||||
        *output_ids,
 | 
			
		||||
        pos_embedding_data_->Slice({decode_step + 1, 0})->DataAs<float>(),
 | 
			
		||||
        full_ulm_->transformer_input_->Slice({0, decode_step + 1, 0})
 | 
			
		||||
            ->DataAs<float>());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  std::unique_ptr<Ulm> full_ulm_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<std::shared_ptr<Tensor>> UlmBuilder::SelfAttentionExcludeNorm(
 | 
			
		||||
    std::shared_ptr<Tensor> input, SelfAttentionArgs args,
 | 
			
		||||
    const SelfAttentionWeights& sa_weights, absl::SourceLocation loc) {
 | 
			
		||||
  // [B, 1|T, N, H]
 | 
			
		||||
  ASSIGN_OR_RETURN(auto k_proj, SelfAttentionProj(input, sa_weights.k_weight));
 | 
			
		||||
  ASSIGN_OR_RETURN(auto q_proj, SelfAttentionProj(input, sa_weights.q_weight));
 | 
			
		||||
  ASSIGN_OR_RETURN(auto v_proj, SelfAttentionProj(input, sa_weights.v_weight));
 | 
			
		||||
 | 
			
		||||
  ASSIGN_OR_RETURN(auto query_proj_after_rope, Rope(q_proj, args.segment_pos));
 | 
			
		||||
  ASSIGN_OR_RETURN(auto key_proj_after_rope, Rope(k_proj, args.segment_pos));
 | 
			
		||||
 | 
			
		||||
  if (args.cache) {
 | 
			
		||||
    RET_CHECK(args.cache->k_cache);
 | 
			
		||||
    RET_CHECK(args.cache->v_cache);
 | 
			
		||||
    // When cache is provided, there are 2 cases:
 | 
			
		||||
    if (*(input->dims.end() - 2) != 1) {
 | 
			
		||||
      // Building a normal graph, which is used to initialize cache.
 | 
			
		||||
      key_proj_after_rope->Borrow(args.cache->k_cache).MarkOutput();
 | 
			
		||||
      v_proj->Borrow(args.cache->v_cache).MarkOutput();
 | 
			
		||||
    } else {
 | 
			
		||||
      // Building a one-token graph, which consumes initialized cache.
 | 
			
		||||
      key_proj_after_rope->MarkOutput();
 | 
			
		||||
      args.cache->k_slice = key_proj_after_rope;
 | 
			
		||||
      v_proj->MarkOutput();
 | 
			
		||||
      args.cache->v_slice = v_proj;
 | 
			
		||||
 | 
			
		||||
      ASSIGN_OR_RETURN(key_proj_after_rope,
 | 
			
		||||
                       NewInput(args.cache->k_cache->dims));
 | 
			
		||||
      key_proj_after_rope->Borrow(args.cache->k_cache);
 | 
			
		||||
      ASSIGN_OR_RETURN(v_proj, NewInput(args.cache->v_cache->dims));
 | 
			
		||||
      v_proj->Borrow(args.cache->v_cache);
 | 
			
		||||
    }
 | 
			
		||||
  }
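  // In the one-token case, the freshly computed K/V slices above are exported
  // as graph outputs (written back into the caches between decode steps),
  // while the attention below consumes NewInput tensors aliasing the full
  // caches.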

  // encoded, [B, 1|T, N, H]
  ASSIGN_OR_RETURN(
      auto kqv_merged,
      DotAttention(query_proj_after_rope, key_proj_after_rope, v_proj,
                   args.atten_mask, sa_weights.per_dim_scale));

  const size_t B = kqv_merged->dims[0];
  const size_t T_or_1 = kqv_merged->dims[1];
  const size_t NH = kqv_merged->num_elements / (B * T_or_1);
  ASSIGN_OR_RETURN(auto outcome_reshaped, Reshape(kqv_merged, {B, T_or_1, NH}));

  return MatMul(outcome_reshaped, sa_weights.post_proj_weight,
                {.transpose = false});
}

absl::StatusOr<std::shared_ptr<Tensor>>
UlmBuilder::SelfAttentionIncludeResidual(std::shared_ptr<Tensor> input,
                                         SelfAttentionArgs args,
                                         const SelfAttentionWeights& params,
                                         absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto pre_attention, RmsNorm(input, params.pre_norm));

  ASSIGN_OR_RETURN(
      auto post_attention,
      SelfAttentionExcludeNorm(pre_attention, std::move(args), params));

  ASSIGN_OR_RETURN(auto post_norm, RmsNorm(post_attention, params.post_norm));

  return ElementAdd(input, post_norm);
}

absl::StatusOr<std::shared_ptr<Tensor>> UlmBuilder::FeedForwardExcludeResidual(
    std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
    absl::SourceLocation loc) {
  ASSIGN_OR_RETURN(auto first_rms_norm, RmsNorm(input, params.pre_norm));

  ASSIGN_OR_RETURN(auto layer_1, FullConn(first_rms_norm, params.layer_1_weight,
                                          params.layer_1_bias));

  ASSIGN_OR_RETURN(auto layer_1_gate_before_gelu,
                   FullConn(first_rms_norm, params.layer_1_gate_weight,
                            params.layer_1_gate_bias));
  ASSIGN_OR_RETURN(auto layer_1_gate, Gelu(layer_1_gate_before_gelu));

  ASSIGN_OR_RETURN(auto layer_1_and_gate, ElementMul(layer_1, layer_1_gate));
  if (params.opt_padding) {
    // activations *= 1.0 - paddings
    ASSIGN_OR_RETURN(auto tmp, ElementMul(params.opt_padding, -1.0f));
    ASSIGN_OR_RETURN(tmp, ElementMul(layer_1_and_gate, tmp));
    ASSIGN_OR_RETURN(layer_1_and_gate, ElementAdd(tmp, layer_1_and_gate));
  }
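  // The three ops above compute act + act * (-padding), i.e.
  // act * (1 - padding), using only the elementwise mul/add primitives the
  // builder exposes.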
  ASSIGN_OR_RETURN(
      auto layer_2,
      FullConn(layer_1_and_gate, params.layer_2_weight, params.layer_2_bias));
  if (params.opt_padding) {
    // activations *= 1.0 - paddings
    ASSIGN_OR_RETURN(auto tmp, ElementMul(params.opt_padding, -1.0f));
    ASSIGN_OR_RETURN(tmp, ElementMul(layer_2, tmp));
    ASSIGN_OR_RETURN(layer_2, ElementAdd(tmp, layer_2));
  }

  return RmsNorm(layer_2, params.post_norm);
}
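
// Taken together, FeedForwardExcludeResidual is a gated (GeGLU-style) MLP:
// post_norm(W2((W1 x) * GELU(W1g x))), with optional padding masking applied
// after each linear layer.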

absl::StatusOr<std::unique_ptr<Ulm>> Ulm::CreateUlm(
    absl::string_view weights_folder, const UlmParams& ulm_params,
    std::unique_ptr<RuntimeConfigs> runtime_configs) {
  auto weight_loader =
      std::make_unique<DefaultUlmWeightsLoader>(weights_folder, ulm_params);
  return CreateUlm(std::move(weight_loader), std::move(runtime_configs));
}

absl::StatusOr<std::unique_ptr<Ulm>> Ulm::CreateOneTokenUlm(
    std::unique_ptr<UlmWeightsLoader> weight_loader,
    std::unique_ptr<RuntimeConfigs> runtime_configs) {
  UlmBuilder builder;
  // TODO: this might waste memory; benchmark it.
  weight_loader->SetBuilder(builder);
  ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights());

  UlmParams ulm_params = weight_loader->ulm_params();
  ulm_params.enable_kv_cache = true;

  weight_loader->ulm_params().enable_kv_cache = true;
  weight_loader->ulm_params().final_norm = false;
  weight_loader->ulm_params().final_project = false;
  ASSIGN_OR_RETURN(auto full_ulm, CreateUlm(std::move(weight_loader)));

  ASSIGN_OR_RETURN(auto input, builder.NewInput({ulm_params.batch_size_B, 1,
                                                 ulm_params.model_dim_D}));
  ASSIGN_OR_RETURN(auto atten_masks,
                   builder.NewInput({1, ulm_params.seq_size_T}));
  ASSIGN_OR_RETURN(auto segment_pos,
                   builder.NewWeight({1, ulm_params.head_dim_H}));
  // Allocate the buffer before creating the runtime.
  MP_RETURN_IF_ERROR(segment_pos->LoadFromVec({}, /*exact_match=*/false));

  std::vector<KVCache>& kv_cache = full_ulm->kv_cache_;
  RET_CHECK_EQ(kv_cache.size(), ulm_params.num_transformer_M);

  auto inter_layer = input;
  for (int i = 0; i < ulm_params.num_transformer_M; ++i) {
    const auto& sa = weights.sas[i];
    ASSIGN_OR_RETURN(auto tmp, builder.SelfAttentionIncludeResidual(
                                   inter_layer,
                                   {.atten_mask = atten_masks,
                                    .segment_pos = segment_pos,
                                    .cache = &kv_cache[i]},
                                   sa));

    auto& ff = weights.ffs[i];
    // ff.opt_padding = paddings;
    ASSIGN_OR_RETURN(inter_layer, builder.FeedForwardIncludeResidual(tmp, ff));
  }

  std::shared_ptr<Tensor> logits_output, transformer_output, normed_output;

  if (ulm_params.final_norm) {
    ASSIGN_OR_RETURN(inter_layer,
                     builder.RmsNorm(inter_layer, weights.final_ln_scale));
    normed_output = inter_layer;
    normed_output->MarkOutput();
  }
  if (ulm_params.final_project) {
    RET_CHECK(weights.softmax_linear);
    ASSIGN_OR_RETURN(logits_output,
                     ApplyFinalProj(inter_layer, weights, builder));
    logits_output->MarkOutput();
  }

  ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs)));
  Ulm* full_ulm_p = full_ulm.get();
  auto result =
      std::make_unique<OneTokenUlm>(std::move(full_ulm), std::move(*graph));
  {
    Tensor::DimsType dims{ulm_params.seq_size_T, ulm_params.model_dim_D};
    result->pos_embedding_data_ =
        std::make_shared<Tensor>(std::move(dims), xnn_datatype_fp32);
    result->pos_embedding_data_->Borrow(full_ulm_p->pos_embedding_data_);
  }
  result->transformer_input_ = input;
  result->transformer_output_ = transformer_output;
  result->normed_output_ = normed_output;
  result->logits_output_ = logits_output;
  result->segment_pos_ = segment_pos;
  result->atten_masks_ = atten_masks;
  if (ulm_params.use_padding) {
    // result->paddings_ = paddings;
  }
  result->kv_cache_ = std::move(kv_cache);

  result->weights_ = std::move(weights);
  result->ulm_params_ = ulm_params;

  return result;
}

absl::StatusOr<std::unique_ptr<Ulm>> Ulm::CreateUlm(
    std::unique_ptr<UlmWeightsLoader> weight_loader,
    std::unique_ptr<RuntimeConfigs> runtime_configs) {
  UlmBuilder builder;
  weight_loader->SetBuilder(builder);
  const auto& ulm_params = weight_loader->ulm_params();
  RET_CHECK_NE(ulm_params.batch_size_B, 0);

  ASSIGN_OR_RETURN(auto input, builder.NewInput({ulm_params.batch_size_B,
                                                 ulm_params.seq_size_T,
                                                 ulm_params.model_dim_D}));
  ASSIGN_OR_RETURN(auto atten_masks, builder.NewInput({ulm_params.seq_size_T,
                                                       ulm_params.seq_size_T}));
  VLOG(1) << "atten mask id " << atten_masks->tensor_id;
  ASSIGN_OR_RETURN(
      auto segment_pos,
      builder.NewWeight({ulm_params.seq_size_T, ulm_params.head_dim_H}));
  MP_RETURN_IF_ERROR(FillXnnRoPEWeights(*segment_pos));
  VLOG(1) << "segment pos id " << segment_pos->tensor_id;
  std::shared_ptr<Tensor> paddings;
  if (ulm_params.use_padding) {
    ASSIGN_OR_RETURN(paddings, builder.NewInput({ulm_params.batch_size_B,
                                                 ulm_params.seq_size_T, 1}));
    VLOG(1) << "paddings id " << paddings->tensor_id;
  }

  ASSIGN_OR_RETURN(auto weights, weight_loader->LoadWeights());
  std::vector<KVCache> kv_cache;

  auto inter_layer = input;
  for (int i = 0; i < ulm_params.num_transformer_M; ++i) {
    const auto& sa = weights.sas[i];
    KVCache* cache = nullptr;
    if (ulm_params.enable_kv_cache) {
      auto k_cache = std::make_shared<Tensor>(
          Tensor::DimsType{ulm_params.batch_size_B, ulm_params.seq_size_T,
                           ulm_params.n_heads_N, ulm_params.head_dim_H});
      MP_RETURN_IF_ERROR(k_cache->LoadFromVec({}, /*exact_match=*/false));
      auto v_cache = std::make_shared<Tensor>(
          Tensor::DimsType{ulm_params.batch_size_B, ulm_params.seq_size_T,
                           ulm_params.n_heads_N, ulm_params.head_dim_H});
      MP_RETURN_IF_ERROR(v_cache->LoadFromVec({}, /*exact_match=*/false));
      kv_cache.push_back(KVCache{.k_cache = k_cache, .v_cache = v_cache});
      cache = &kv_cache.back();
    }
    ASSIGN_OR_RETURN(auto tmp, builder.SelfAttentionIncludeResidual(
                                   inter_layer,
                                   {.atten_mask = atten_masks,
                                    .segment_pos = segment_pos,
                                    .cache = cache},
                                   sa));

    auto& ff = weights.ffs[i];
    ff.opt_padding = paddings;
    ASSIGN_OR_RETURN(inter_layer, builder.FeedForwardIncludeResidual(tmp, ff));
  }

  std::shared_ptr<Tensor> logits_output, transformer_output, normed_output;

  if (!ulm_params.final_norm && !ulm_params.final_project) {
    transformer_output = inter_layer;
    transformer_output->MarkOutput();
  }

  if (ulm_params.final_norm) {
    ASSIGN_OR_RETURN(inter_layer,
                     builder.RmsNorm(inter_layer, weights.final_ln_scale));
    normed_output = inter_layer;
    normed_output->MarkOutput();
  }

  if (ulm_params.final_project) {
    RET_CHECK(weights.softmax_linear);
    ASSIGN_OR_RETURN(logits_output,
                     ApplyFinalProj(inter_layer, weights, builder));
    logits_output->MarkOutput();
  }

  ASSIGN_OR_RETURN(auto graph, builder.Build(std::move(runtime_configs)));
  auto ulm = std::make_unique<Ulm>(std::move(*graph));
  {
    ASSIGN_OR_RETURN(auto pos_embedding_data,
                     mediapipe::PositionEmbedding(ulm_params.seq_size_T,
                                                  ulm_params.model_dim_D));
    Tensor::DimsType dims{ulm_params.seq_size_T, ulm_params.model_dim_D};
    ulm->pos_embedding_data_ =
        std::make_shared<Tensor>(std::move(dims), xnn_datatype_fp32);
    MP_RETURN_IF_ERROR(
        ulm->pos_embedding_data_->LoadFromVec(pos_embedding_data));
  }
  ulm->transformer_input_ = input;
  ulm->transformer_output_ = transformer_output;
  ulm->normed_output_ = normed_output;
  ulm->logits_output_ = logits_output;
  ulm->segment_pos_ = segment_pos;
  ulm->atten_masks_ = atten_masks;
  if (ulm_params.use_padding) {
    ulm->paddings_ = paddings;
  }
  ulm->kv_cache_ = std::move(kv_cache);

  ulm->weights_ = std::move(weights);
  ulm->ulm_params_ = ulm_params;

  return ulm;
}

absl::Status Ulm::InitInputTokens(const std::vector<int>& input_ids) {
  prev_ids_ = input_ids;

  constexpr float neg_value = 0.7 * std::numeric_limits<float>::lowest();
  const auto& seq_size = ulm_params_.seq_size_T;
  std::vector<float> attention_array(seq_size * seq_size, neg_value);
  for (int i = 0; i < seq_size; ++i) {
    for (int j = 0; j < seq_size; ++j) {
      if (i < input_ids.size() && j < input_ids.size()) {
        // Prompt positions attend to the full prompt.
        attention_array[seq_size * i + j] = 0;
      } else if (i >= input_ids.size() && j <= i) {
        // Positions past the prompt attend causally.
        attention_array[seq_size * i + j] = 0;
      } else {
        break;
      }
    }
  }
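  // For illustration, with seq_size_T = 4 and a 2-token prompt this yields
  // (0 = attend, n = neg_value):
  //   row 0: 0 0 n n
  //   row 1: 0 0 n n
  //   row 2: 0 0 0 n
  //   row 3: 0 0 0 0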

  MP_RETURN_IF_ERROR(atten_masks_->LoadFromVec(attention_array));

  MP_RETURN_IF_ERROR(GetTokenEmbedding(input_ids,
                                       pos_embedding_data_->DataAs<float>(),
                                       transformer_input_->DataAs<float>()));
  return SetupRuntime();
}

absl::Status Ulm::GetNextToken(std::vector<int>* output_ids) {
  VLOG(2) << "Decode step " << prev_ids_.size() - 1;

  MP_RETURN_IF_ERROR(Run());

  RET_CHECK(logits_output_);
  std::shared_ptr<Tensor> logits =
      logits_output_->Slice({0, prev_ids_.size() - 1, 0});
  DCHECK_EQ(logits->num_elements, ulm_params_.voc_size_V);

  ASSIGN_OR_RETURN(*output_ids,
                   mediapipe::SampleNextToken(
                       logits->DataAs<float>(),
                       /*batch_size=*/1,
                       /*vocab_size=*/ulm_params_.voc_size_V, /*top_k=*/10,
                       /*top_p=*/1, /*temperature=*/-1));
  RET_CHECK_EQ(output_ids->size(), 1);
  prev_ids_.push_back(output_ids->at(0));

  return GetTokenEmbedding(
      *output_ids,
      pos_embedding_data_->Slice({prev_ids_.size() - 1, 0})->DataAs<float>(),
      transformer_input_->Slice({0, prev_ids_.size() - 1, 0})->DataAs<float>());
}

absl::Status Ulm::GetTokenEmbedding(const std::vector<int>& ids,
                                    const float* pos_embedding_data,
                                    float* embedding) {
  auto token_embedding = weights_.token_embedding ? weights_.token_embedding
                                                  : weights_.softmax_linear;
  RET_CHECK(token_embedding->dims[0] == ulm_params_.voc_size_V)
      << "shape must be [vocab_size, _], such that the following Slice() "
         "makes sense.";
  for (size_t id : ids) {
    memcpy(embedding, token_embedding->Slice(0, id)->Data(),
           ulm_params_.model_dim_D * sizeof(float));
    for (size_t i = 0; i < ulm_params_.model_dim_D; ++i) {
      embedding[i] += pos_embedding_data[i];
    }
    pos_embedding_data += ulm_params_.model_dim_D;
    embedding += ulm_params_.model_dim_D;
  }
  return absl::OkStatus();
}

}  // namespace xnn_utils
}  // namespace mediapipe

@ -1,127 +0,0 @@
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_

#include <cstddef>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"

namespace mediapipe {
namespace xnn_utils {

class Ulm : public XnnGraph {
 public:
  using UlmParams = UlmParams;

  explicit Ulm(XnnGraph&& other) : XnnGraph(std::move(other)) {}
  ~Ulm() override = default;

  // Creates a ULM graph with default params. The defaults correspond to the
  // ULM1B 256k model.
  static absl::StatusOr<std::unique_ptr<Ulm>> CreateUlm(
      absl::string_view weights_folder,
      const UlmParams& ulm_params =
          UlmParams{
              .num_transformer_M = 18,
              .batch_size_B = 1,
              .seq_size_T = 16,
              .model_dim_D = 1536,
              .hidden_dim_HD = 8 * 1536,
              .head_dim_H = 128,
              .n_heads_N = 12,
              .voc_size_V = 256128,
          },
      std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
  static absl::StatusOr<std::unique_ptr<Ulm>> CreateUlm(
      std::unique_ptr<UlmWeightsLoader> weight_loader,
      std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);
  // Builds the graph for one-token inference.
  static absl::StatusOr<std::unique_ptr<Ulm>> CreateOneTokenUlm(
      std::unique_ptr<UlmWeightsLoader> weight_loader,
      std::unique_ptr<RuntimeConfigs> runtime_configs = nullptr);

  // (Re)Initializes with input token ids. This resets the cache, mask, etc.
  virtual absl::Status InitInputTokens(const std::vector<int>& input_ids);

  // Gets the next token id.
  virtual absl::Status GetNextToken(std::vector<int>* output_ids);
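
  // A minimal decode-loop sketch (the weights path and token ids below are
  // hypothetical placeholders):
  //
  //   ASSIGN_OR_RETURN(auto ulm, Ulm::CreateUlm("/path/to/weights"));
  //   MP_RETURN_IF_ERROR(ulm->InitInputTokens({1, 42, 7}));
  //   std::vector<int> ids;
  //   MP_RETURN_IF_ERROR(ulm->GetNextToken(&ids));  // one id per call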

 protected:
  friend class OneTokenUlm;
  friend class UlmTest;
  friend class UlmBuilder;

  // Used iff enable_kv_cache is set.
  struct KVCache {
    std::shared_ptr<Tensor> k_cache;
    std::shared_ptr<Tensor> v_cache;
    std::shared_ptr<Tensor> k_slice;
    std::shared_ptr<Tensor> v_slice;
  };
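  // In CreateUlm, k_cache/v_cache are allocated as
  // [batch_size_B, seq_size_T, n_heads_N, head_dim_H]; k_slice/v_slice are
  // per-step views into them (see OneTokenUlm::GetNextToken).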

  absl::Status GetTokenEmbedding(const std::vector<int>& ids,
                                 const float* pos_embedding_data,
                                 float* embedding);

  UlmWeights weights_;
  UlmParams ulm_params_;

  std::shared_ptr<Tensor> pos_embedding_data_;
  std::shared_ptr<Tensor> atten_masks_;
  std::shared_ptr<Tensor> segment_pos_;
  std::shared_ptr<Tensor> paddings_;

  std::shared_ptr<Tensor> transformer_input_;
  std::shared_ptr<Tensor> transformer_output_;
  std::shared_ptr<Tensor> normed_output_;
  std::shared_ptr<Tensor> logits_output_;

  // Previous ids, including the prompt.
  std::vector<int> prev_ids_;
  // If enable_kv_cache, expects a mask of [0, ..., 0, 1, 0, 0, ...], size
  // 1 x T.
  std::shared_ptr<Tensor> decode_step_mask_;
  // [1, 1, ..., 1, 0, 0, ...], applied on the cache.
  std::shared_ptr<Tensor> decode_step_mask_for_cache_;
  std::vector<KVCache> kv_cache_;
};

class UlmBuilder : public XnnGraphBuilder {
 public:
  struct SelfAttentionArgs {
    std::shared_ptr<Tensor> atten_mask;
    std::shared_ptr<Tensor> segment_pos;

    Ulm::KVCache* cache = nullptr;
  };

  absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionExcludeNorm(
      std::shared_ptr<Tensor> input, SelfAttentionArgs args,
      const SelfAttentionWeights& sa_weights,
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> SelfAttentionIncludeResidual(
      std::shared_ptr<Tensor> input, SelfAttentionArgs args,
      const SelfAttentionWeights& params,
      absl::SourceLocation loc = absl::SourceLocation::current());

  absl::StatusOr<std::shared_ptr<Tensor>> FeedForwardExcludeResidual(
      std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
      absl::SourceLocation loc = absl::SourceLocation::current());
  absl::StatusOr<std::shared_ptr<Tensor>> FeedForwardIncludeResidual(
      std::shared_ptr<Tensor> input, const FeedForwardWeights& params,
      absl::SourceLocation loc = absl::SourceLocation::current());
};

}  // namespace xnn_utils
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_H_

@ -1,366 +0,0 @@
#include "mediapipe/tasks/cc/text/utils/xnn_utils/ulm_weights.h"
 | 
			
		||||
 | 
			
		||||
#include <cmath>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "file/base/filesystem.h"
 | 
			
		||||
#include "file/base/options.h"
 | 
			
		||||
#include "file/base/path.h"
 | 
			
		||||
#include "mediapipe/framework/port/ret_check.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_macros.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
 | 
			
		||||
#include "third_party/XNNPACK/include/xnnpack.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace xnn_utils {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefixHelper(
 | 
			
		||||
    XnnGraphBuilder& builder, absl::string_view prefix,
 | 
			
		||||
    const Tensor::DimsType& dims, size_t dim_scale_if_any) {
 | 
			
		||||
  RET_CHECK(!prefix.empty() && prefix.back() != '.');
 | 
			
		||||
  std::vector<std::string> filenames;
 | 
			
		||||
  auto s = file::Match(absl::StrCat(prefix, "*"), &filenames, file::Defaults());
 | 
			
		||||
  if (!s.ok()) {
 | 
			
		||||
    LOG(WARNING) << s;
 | 
			
		||||
    return nullptr;
 | 
			
		||||
  } else if (filenames.empty()) {
 | 
			
		||||
    return nullptr;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (filenames.size() == 1) {
 | 
			
		||||
    RET_CHECK_EQ(filenames[0], prefix);
 | 
			
		||||
    return builder.NewWeight(filenames[0], dims);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool is_quantized_tensor = false;
 | 
			
		||||
  for (const auto& filename : filenames) {
 | 
			
		||||
    if (absl::StrContains(filename, kQuantizedScaleSuffix)) {
 | 
			
		||||
      is_quantized_tensor = true;
 | 
			
		||||
      continue;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  RET_CHECK(is_quantized_tensor)
 | 
			
		||||
      << "At least one of {" << filenames << "} must be quantize scale file.";
 | 
			
		||||
 | 
			
		||||
  std::shared_ptr<Tensor> result;
 | 
			
		||||
  result = std::make_shared<QCTensor>(dims, dim_scale_if_any);
 | 
			
		||||
 | 
			
		||||
  MP_RETURN_IF_ERROR(result->LoadFromFile(prefix));
 | 
			
		||||
  builder.NewWeight(result);
 | 
			
		||||
 | 
			
		||||
  return result;
 | 
			
		||||
}
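
// For illustration (an assumed on-disk layout, inferred from the matching
// above): a quantized weight stored under prefix "foo.w" is a pair of files,
// the quantized data ("foo.w") plus a per-channel scale file whose name
// contains kQuantizedScaleSuffix; a plain float weight is a single file whose
// name equals the prefix exactly.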

absl::Status TransposeSelfAttentionWeight(
    const UlmWeightsLoader& loader, std::shared_ptr<Tensor>& original_weight,
    absl::string_view cache_file_prefix) {
  const auto& ulm_param = loader.ulm_params();
  RET_CHECK(original_weight);

  std::optional<int> from_cache =
      original_weight->GetMetadata(UlmWeights::kKeyLoadedFromCache);
  if (from_cache && *from_cache) {
    return absl::OkStatus();
  }

  if (auto s = original_weight->DumpToFile(cache_file_prefix); !s.ok()) {
    LOG(WARNING) << s;
  } else {
    MP_RETURN_IF_ERROR(original_weight->LoadFromFile(cache_file_prefix));
  }
  loader.builder().NewWeight(original_weight);
  original_weight->SetMetadata(XnnGraphBuilder::kKeySelfAttentionReshapedWeight,
                               ulm_param.n_heads_N);
  return absl::OkStatus();
}

}  // namespace

absl::Status PrepareTokenEmbeddingDecorator::Decorate(
    const UlmWeightsLoader& loader, UlmWeights& weight) {
  if (weight.token_embedding) {
    return absl::OkStatus();
  }

  const auto& ulm_params = loader.ulm_params();
  absl::string_view cache_path = loader.ulm_params().weight_cache_path;
  std::string token_embedding_cache_path =
      cache_path.empty() ? "" : file::JoinPath(cache_path, "token_embedding.w");
  // 1. Try the cache.
  if (!token_embedding_cache_path.empty()) {
    auto token_embedding =
        Tensor::FromFile(token_embedding_cache_path,
                         {ulm_params.voc_size_V, ulm_params.model_dim_D});
    if (token_embedding.ok()) {
      weight.token_embedding = *token_embedding;
      return absl::OkStatus();
    }
  }

  // 2. Fill the embedding from softmax_linear.
  auto& softmax_linear = *weight.softmax_linear;
  RET_CHECK(softmax_linear.dims[0] == ulm_params.voc_size_V) << softmax_linear;
  if (softmax_linear.datatype == xnn_datatype_fp32) {
    weight.token_embedding = softmax_linear.View();
  } else if (softmax_linear.datatype == xnn_datatype_qcint8) {
    ASSIGN_OR_RETURN(weight.token_embedding, softmax_linear.ConvertToF32());
  }

  float* embedding_data = weight.token_embedding->DataAs<float>();
  for (size_t i = 0; i < softmax_linear.num_elements; ++i) {
    embedding_data[i] *= std::sqrt(loader.ulm_params().model_dim_D);
  }

  // 3. Save to the cache.
  if (!token_embedding_cache_path.empty()) {
    MP_RETURN_IF_ERROR(
        weight.token_embedding->DumpToFile(token_embedding_cache_path));
    return weight.token_embedding->LoadFromFile(token_embedding_cache_path);
  }

  return absl::OkStatus();
}

absl::Status TransposeSelfAttentionWeightDecorator::Decorate(
    const UlmWeightsLoader& loader, UlmWeights& weight) {
  absl::string_view cache_path = loader.ulm_params().weight_cache_path;
  if (cache_path.empty()) {
    return absl::OkStatus();
  }

  for (size_t i = 0; i < weight.sas.size(); ++i) {
    auto& sa = weight.sas[i];
    auto prefix = absl::StrCat(UlmWeightsLoader::kTransformerWeightPrefix, i,
                               ".self_attention.");
    MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight(
        loader, sa.k_weight,
        file::JoinPath(cache_path, absl::StrCat(prefix, "k.w"))));
    MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight(
        loader, sa.q_weight,
        file::JoinPath(cache_path, absl::StrCat(prefix, "q.w"))));
    MP_RETURN_IF_ERROR(TransposeSelfAttentionWeight(
        loader, sa.v_weight,
        file::JoinPath(cache_path, absl::StrCat(prefix, "v.w"))));
  }

  return absl::OkStatus();
}

absl::StatusOr<std::shared_ptr<Tensor>> UlmWeightsLoader::LoadFromAbsPathPrefix(
    absl::string_view prefix, const Tensor::DimsType& dims,
    size_t dim_scale_if_any) const {
  return LoadFromAbsPathPrefixHelper(*builder_, prefix, dims, dim_scale_if_any);
}

absl::StatusOr<std::shared_ptr<Tensor>>
UlmWeightsLoader::TryCacheThenLoadSelfAttention(
    absl::string_view filename_prefix) const {
  ASSIGN_OR_RETURN(
      auto r,
      TryCacheThenLoadWeightTranspose(
          filename_prefix,
          {params_.model_dim_D, params_.n_heads_N * params_.head_dim_H}, 1));
  r->SetMetadata(XnnGraphBuilder::kKeySelfAttentionReshapedWeight,
                 params_.n_heads_N);
  return r;
}

absl::StatusOr<std::shared_ptr<Tensor>>
UlmWeightsLoader::TryCacheThenLoadFeedForward(
    absl::string_view filename_prefix,
    std::optional<Tensor::DimsType> dims) const {
  if (!dims) {
    dims = {params_.model_dim_D, params_.hidden_dim_HD};
  }
  return TryCacheThenLoadWeightTranspose(filename_prefix, *dims, 1);
}

absl::StatusOr<std::shared_ptr<Tensor>>
UlmWeightsLoader::TryCacheThenLoadWeightTranspose(
    absl::string_view filename_prefix, Tensor::DimsType original_dims,
    size_t original_dim_scale) const {
  if (!params_.weight_cache_path.empty()) {
    auto cache_full_prefix =
        file::JoinPath(params_.weight_cache_path, filename_prefix);
    Tensor::DimsType cache_dim{original_dims.rbegin(), original_dims.rend()};
    ASSIGN_OR_RETURN(auto r, LoadFromAbsPathPrefix(
                                 cache_full_prefix, std::move(cache_dim),
                                 /*dim_scale_if_any=*/1 - original_dim_scale));
    if (r) {
      r->SetMetadata(UlmWeights::kKeyLoadedFromCache, 1);
      return r;
    }
  }

  ASSIGN_OR_RETURN(auto r, LoadFromAbsPathPrefix(
                               file::JoinPath(weight_path_, filename_prefix),
                               std::move(original_dims),
                               /*dim_scale_if_any=*/original_dim_scale));
  RET_CHECK(r) << file::JoinPath(weight_path_, filename_prefix);
  r = r->Transpose();
  builder_->NewWeight(r);
  return r;
}
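
// For illustration: a weight stored on disk as {D, V} with its quantization
// scale on dim 1 is cached transposed as {V, D} with the scale on dim 0, so a
// cache hit skips the Transpose() above on subsequent runs.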

absl::StatusOr<FeedForwardWeights> UlmWeightsLoader::LoadFeedForward(
    int layer_id) {
  absl::string_view weights_folder = weight_path_;
  const auto& params = params_;
  auto ff_file_prefix =
      absl::StrCat(kTransformerWeightPrefix, layer_id, ".ff_layer.");
  auto ff_prefix = file::JoinPath(weights_folder, ff_file_prefix);
  FeedForwardWeights feed_forward;

  ASSIGN_OR_RETURN(
      feed_forward.pre_norm,
      LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "pre_layer_norm.scale"),
                            {params.model_dim_D}));
  ASSIGN_OR_RETURN(
      feed_forward.post_norm,
      LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "post_layer_norm.scale"),
                            {params.model_dim_D}));
  ASSIGN_OR_RETURN(
      feed_forward.layer_1_bias,
      LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer1.bias.b"),
                            {params.hidden_dim_HD}));
  ASSIGN_OR_RETURN(feed_forward.layer_1_weight,
                   TryCacheThenLoadFeedForward(
                       absl::StrCat(ff_file_prefix, "ffn_layer1.linear.w")));
  ASSIGN_OR_RETURN(
      feed_forward.layer_1_gate_bias,
      LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer1_gate.bias.b"),
                            {params.hidden_dim_HD}));
  ASSIGN_OR_RETURN(feed_forward.layer_1_gate_weight,
                   TryCacheThenLoadFeedForward(absl::StrCat(
                       ff_file_prefix, "ffn_layer1_gate.linear.w")));
  ASSIGN_OR_RETURN(
      feed_forward.layer_2_bias,
      LoadFromAbsPathPrefix(absl::StrCat(ff_prefix, "ffn_layer2.bias.b"),
                            {params.model_dim_D}, /*dim_scale_if_any=*/0));
  ASSIGN_OR_RETURN(
      feed_forward.layer_2_weight,
      TryCacheThenLoadFeedForward(
          absl::StrCat(ff_file_prefix, "ffn_layer2.linear.w"),
          Tensor::DimsType{params.hidden_dim_HD, params.model_dim_D}));

  return feed_forward;
}

absl::StatusOr<SelfAttentionWeights> UlmWeightsLoader::LoadSelfAttention(
    int layer_id) {
  absl::string_view weights_folder = weight_path_;
  const auto& params = params_;
  SelfAttentionWeights self_attention;

  auto sa_file_prefix = absl::StrCat(kTransformerWeightPrefix, layer_id);
  auto sa_prefix = file::JoinPath(weights_folder, sa_file_prefix);
  ASSIGN_OR_RETURN(
      self_attention.pre_norm,
      LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, ".pre_layer_norm.scale"),
                            {params.model_dim_D}));
  ASSIGN_OR_RETURN(
      self_attention.post_norm,
      LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, ".post_layer_norm.scale"),
                            {params.model_dim_D}));

  absl::StrAppend(&sa_file_prefix, ".self_attention.");

  ASSIGN_OR_RETURN(
      self_attention.k_weight,
      TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "k.w")));
  ASSIGN_OR_RETURN(
      self_attention.q_weight,
      TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "q.w")));
  ASSIGN_OR_RETURN(
      self_attention.v_weight,
      TryCacheThenLoadSelfAttention(absl::StrCat(sa_file_prefix, "v.w")));

  sa_prefix = file::JoinPath(weights_folder, sa_file_prefix);
  ASSIGN_OR_RETURN(self_attention.per_dim_scale,
                   LoadFromAbsPathPrefix(
                       absl::StrCat(sa_prefix, "per_dim_scale.per_dim_scale"),
                       {params.head_dim_H}));
  ASSIGN_OR_RETURN(self_attention.post_proj_weight,
                   LoadFromAbsPathPrefix(absl::StrCat(sa_prefix, "post.w"),
                                         {params.model_dim_D,
                                          params.n_heads_N * params.head_dim_H},
                                         /*dim_scale_if_any=*/0));

  return self_attention;
}

absl::StatusOr<UlmWeights> UlmWeightsLoader::LoadWeights() {
  absl::string_view weights_folder = weight_path_;
  const auto& params = params_;
  UlmWeights result;

  for (int layer_id = 0; layer_id < params.num_transformer_M; ++layer_id) {
    ASSIGN_OR_RETURN(auto ff, LoadFeedForward(layer_id));
    result.ffs.push_back(std::move(ff));
    ASSIGN_OR_RETURN(auto sa, LoadSelfAttention(layer_id));
    result.sas.push_back(std::move(sa));
  }
  if (params.final_norm) {
    ASSIGN_OR_RETURN(result.final_ln_scale,
                     LoadFromAbsPathPrefix(
                         file::JoinPath(weights_folder, kFinalScaleFilename),
                         {params.model_dim_D}));
  }
  ASSIGN_OR_RETURN(result.softmax_bias,
                   LoadFromAbsPathPrefix(
                       file::JoinPath(weights_folder, kLogitsFfnBiasFilename),
                       {params.voc_size_V}));
  ASSIGN_OR_RETURN(result.softmax_linear,
                   TryCacheThenLoadWeightTranspose(
                       kLogitsFfnWeightFilename,
                       {params.model_dim_D, params.voc_size_V}, 1));

  return result;
}
			
		||||
 | 
			
		||||
BenchmarkUlmWeightsLoader::BenchmarkUlmWeightsLoader(const UlmParams& params,
 | 
			
		||||
                                                     xnn_datatype data_type)
 | 
			
		||||
    : DefaultUlmWeightsLoader("", params), data_type_(data_type) {
 | 
			
		||||
  params_.weight_cache_path.clear();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<std::shared_ptr<Tensor>>
 | 
			
		||||
BenchmarkUlmWeightsLoader::TryCacheThenLoadWeightTranspose(
 | 
			
		||||
    absl::string_view filename_prefix, Tensor::DimsType original_dims,
 | 
			
		||||
    size_t original_dim_cale) const {
 | 
			
		||||
  auto result = std::make_shared<QCTensor>(
 | 
			
		||||
      Tensor::DimsType{original_dims.rbegin(), original_dims.rend()},
 | 
			
		||||
      1 - original_dim_cale);
 | 
			
		||||
  auto real_data = std::make_shared<std::string>(result->num_elements, 0xA5);
 | 
			
		||||
  result->flat_data = std::shared_ptr<char>(real_data, real_data->data());
 | 
			
		||||
  auto real_scale = std::make_shared<std::vector<float>>(
 | 
			
		||||
      original_dims[original_dim_cale], 1.0f);
 | 
			
		||||
  result->scale_data = std::shared_ptr<float>(real_scale, real_scale->data());
 | 
			
		||||
  builder_->NewWeight(result);
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<std::shared_ptr<Tensor>>
 | 
			
		||||
BenchmarkUlmWeightsLoader::LoadFromAbsPathPrefix(
 | 
			
		||||
    absl::string_view prefix, const Tensor::DimsType& dims,
 | 
			
		||||
    size_t dim_scale_if_any) const {
 | 
			
		||||
  // If loader calls this function directly, it's always non-quantized weights.
 | 
			
		||||
  auto result = std::make_shared<Tensor>(dims);
 | 
			
		||||
  MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false));
 | 
			
		||||
  builder_->NewWeight(result);
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace xnn_utils
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
| 
						 | 
				
			
			@ -1,192 +0,0 @@
 | 
			
		|||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_

#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/graph_builder.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
#include "third_party/XNNPACK/include/xnnpack.h"

namespace mediapipe {
namespace xnn_utils {

struct UlmParams {
  size_t num_transformer_M = 18;
  size_t batch_size_B = 1;
  size_t seq_size_T = 16;
  size_t model_dim_D = 1536;
  size_t hidden_dim_HD = 8 * 1536;
  size_t head_dim_H = 128;
  size_t n_heads_N = 12;
  size_t voc_size_V = 32000;

  bool use_padding = true;
  bool final_norm = true;
  bool final_project = true;

  bool enable_kv_cache = false;
  // Path to store reshaped weights as a cache. Leave empty to disable caching.
  std::string weight_cache_path;
};
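
// A usage sketch, not part of the original header: configuring a smaller,
// hypothetical model for tests. The suffix letters on the fields (M, B, T, D,
// HD, H, N, V) name the dimension symbols used throughout this library.
inline UlmParams MakeTinyUlmParamsForTest() {
  UlmParams params;
  params.num_transformer_M = 2;      // two transformer layers
  params.seq_size_T = 8;             // short sequences
  params.model_dim_D = 64;
  params.hidden_dim_HD = 8 * 64;     // keep the 8x feed-forward expansion
  params.head_dim_H = 16;
  params.n_heads_N = 4;
  params.weight_cache_path.clear();  // disable the on-disk weight cache
  return params;
}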

struct SelfAttentionWeights {
  std::shared_ptr<Tensor> pre_norm;

  std::shared_ptr<Tensor> k_weight;
  std::shared_ptr<Tensor> q_weight;
  std::shared_ptr<Tensor> v_weight;
  std::shared_ptr<Tensor> per_dim_scale;
  std::shared_ptr<Tensor> post_proj_weight;

  std::shared_ptr<Tensor> post_norm;
};

struct FeedForwardWeights {
  std::shared_ptr<Tensor> pre_norm;
  std::shared_ptr<Tensor> layer_1_weight;
  std::shared_ptr<Tensor> layer_1_bias;
  std::shared_ptr<Tensor> layer_1_gate_weight;
  std::shared_ptr<Tensor> layer_1_gate_bias;
  std::shared_ptr<Tensor> layer_2_weight;
  std::shared_ptr<Tensor> layer_2_bias;
  std::shared_ptr<Tensor> post_norm;

  std::shared_ptr<Tensor> opt_padding;
};

struct UlmWeights {
  std::vector<FeedForwardWeights> ffs;
  std::vector<SelfAttentionWeights> sas;
  std::shared_ptr<Tensor> final_ln_scale;
  std::shared_ptr<Tensor> softmax_linear;
  std::shared_ptr<Tensor> softmax_bias;

  // Optional. Usually softmax_linear can be used as the embedding, but
  // sometimes we need to scale/transpose it.
  std::shared_ptr<Tensor> token_embedding;

  static constexpr absl::string_view kKeyLoadedFromCache{"loaded_from_cache"};
};

class UlmWeightsLoader {
 public:
  constexpr static absl::string_view kTransformerWeightPrefix{
      "params.lm.transformer.x_layers_"};
  constexpr static absl::string_view kFinalScaleFilename{
      "params.lm.final_ln.scale"};
  constexpr static absl::string_view kLogitsFfnBiasFilename{
      "params.lm.softmax.logits_ffn.bias.b"};
  constexpr static absl::string_view kLogitsFfnWeightFilename{
      "params.lm.softmax.logits_ffn.linear.w"};

  UlmWeightsLoader(absl::string_view weight_path, const UlmParams& params)
      : weight_path_(weight_path), params_(params) {}
  virtual ~UlmWeightsLoader() = default;

  void SetBuilder(XnnGraphBuilder& builder) { builder_ = &builder; }

  virtual absl::StatusOr<UlmWeights> LoadWeights();

  virtual absl::StatusOr<SelfAttentionWeights> LoadSelfAttention(int layer_id);
  virtual absl::StatusOr<FeedForwardWeights> LoadFeedForward(int layer_id);

  UlmParams& ulm_params() { return params_; }
  const UlmParams& ulm_params() const { return params_; }
  XnnGraphBuilder& builder() const { return *builder_; }

 protected:
  // Find the files that match the prefix, then read from them.
  virtual absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefix(
      absl::string_view prefix, const Tensor::DimsType& dims,
      size_t dim_scale_if_any) const;
  absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefix(
      absl::string_view prefix, const Tensor::DimsType& dims) const {
    return LoadFromAbsPathPrefix(prefix, dims, 0);
  }

  absl::StatusOr<std::shared_ptr<Tensor>> TryCacheThenLoadSelfAttention(
      absl::string_view filename_prefix) const;
  absl::StatusOr<std::shared_ptr<Tensor>> TryCacheThenLoadFeedForward(
      absl::string_view filename_prefix,
      std::optional<Tensor::DimsType> dims = std::nullopt) const;
  virtual absl::StatusOr<std::shared_ptr<Tensor>>
  TryCacheThenLoadWeightTranspose(absl::string_view filename_prefix,
                                  Tensor::DimsType original_dims,
                                  size_t original_dim_scale) const;

  std::string weight_path_;
  UlmParams params_;
  XnnGraphBuilder* builder_ = nullptr;
};
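
// A minimal sketch, not part of the original header, of how a loader is
// typically wired up; `weight_dir` is a hypothetical directory of per-tensor
// weight files named as the k*Filename constants above expect.
inline absl::StatusOr<UlmWeights> LoadUlmWeightsSketch(
    absl::string_view weight_dir, const UlmParams& params,
    XnnGraphBuilder& builder) {
  UlmWeightsLoader loader(weight_dir, params);
  loader.SetBuilder(builder);  // weights get registered on this builder
  return loader.LoadWeights();
}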

// Tries, in order: 1. load the token embedding from cache; 2. build the token
// embedding by transposing then scaling softmax_linear; 3. dump the result to
// cache.
struct PrepareTokenEmbeddingDecorator {
  static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&);
};
struct TransposeSoftmaxWeightDecorator {
  static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&);
};
struct TransposeSelfAttentionWeightDecorator {
  // If the KQV weights are already reshaped, do nothing.
  // If the KQV weights are not properly shaped, load from cache if available,
  // or build them.
  // If the KQV weights are missing, try loading from the cache path, and fail
  // if they cannot be found there either.
  static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights&);
};

// Applies the given decorators (in order) to the weights loaded by the base
// class.
template <class... Decorators>
class UlmWeightsLoaderWith : public UlmWeightsLoader {
 public:
  UlmWeightsLoaderWith(absl::string_view weight_path, const UlmParams& params)
      : UlmWeightsLoader(weight_path, params),
        decorators_{Decorators::Decorate...} {}

  absl::StatusOr<UlmWeights> LoadWeights() override {
    ASSIGN_OR_RETURN(auto result, UlmWeightsLoader::LoadWeights());
    for (const auto& decorator : decorators_) {
      MP_RETURN_IF_ERROR(decorator(*this, result));
    }
    return result;
  }

 protected:
  std::vector<std::function<absl::Status(const UlmWeightsLoader&, UlmWeights&)>>
      decorators_;
};
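
// Sketch, not in the original header: a decorator is any struct with a static
// Decorate(const UlmWeightsLoader&, UlmWeights&); it composes through the
// template above. `RequireSoftmaxLinearDecorator` is hypothetical.
struct RequireSoftmaxLinearDecorator {
  static absl::Status Decorate(const UlmWeightsLoader&, UlmWeights& weights) {
    if (weights.softmax_linear == nullptr) {
      return absl::InvalidArgumentError("softmax_linear weight is missing");
    }
    return absl::OkStatus();
  }
};
using CheckedUlmWeightsLoader =
    UlmWeightsLoaderWith<TransposeSelfAttentionWeightDecorator,
                         PrepareTokenEmbeddingDecorator,
                         RequireSoftmaxLinearDecorator>;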

using DefaultUlmWeightsLoader =
    UlmWeightsLoaderWith<TransposeSelfAttentionWeightDecorator,
                         PrepareTokenEmbeddingDecorator>;

// Generates weights filled with placeholder values, for benchmarking.
class BenchmarkUlmWeightsLoader : public DefaultUlmWeightsLoader {
 public:
  explicit BenchmarkUlmWeightsLoader(
      const UlmParams& params, xnn_datatype data_type = xnn_datatype_fp32);

  absl::StatusOr<std::shared_ptr<Tensor>> TryCacheThenLoadWeightTranspose(
      absl::string_view filename_prefix, Tensor::DimsType original_dims,
      size_t original_dim_scale) const override;

  absl::StatusOr<std::shared_ptr<Tensor>> LoadFromAbsPathPrefix(
      absl::string_view prefix, const Tensor::DimsType& dims,
      size_t dim_scale_if_any) const override;

 private:
  xnn_datatype data_type_;
  std::shared_ptr<Tensor> random_value_buffer_;
};

}  // namespace xnn_utils
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_ULM_WEIGHTS_H_
@@ -1,21 +0,0 @@
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace xnn_utils {
 | 
			
		||||
 | 
			
		||||
std::vector<float> FillXnnRoPEWeights(size_t max_seq_len, size_t num_channels) {
 | 
			
		||||
  std::vector<float> out_array(max_seq_len * num_channels);
 | 
			
		||||
  for (size_t ch_id = 0; ch_id < num_channels / 2; ++ch_id) {
 | 
			
		||||
    auto timescale = std::pow(1e-4, 2.0 * ch_id / num_channels);
 | 
			
		||||
    for (size_t seq_id = 0; seq_id < max_seq_len; ++seq_id) {
 | 
			
		||||
      auto sinusoid_inp = seq_id * timescale;
 | 
			
		||||
      out_array[seq_id * num_channels + ch_id] = cos(sinusoid_inp);
 | 
			
		||||
      out_array[seq_id * num_channels + ch_id + num_channels / 2] =
 | 
			
		||||
          sin(sinusoid_inp);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return out_array;
 | 
			
		||||
}
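
// The table above is the standard RoPE construction: for each channel pair
// c in [0, C/2) and position t, timescale(c) = 1e-4^(2c / C), and
//   out[t * C + c]         = cos(t * timescale(c))
//   out[t * C + c + C / 2] = sin(t * timescale(c)).
// Usage sketch (sizes are illustrative):
//   auto rope = FillXnnRoPEWeights(/*max_seq_len=*/16, /*num_channels=*/128);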

}  // namespace xnn_utils
}  // namespace mediapipe
@@ -1,61 +0,0 @@
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_

#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"

namespace mediapipe {
namespace xnn_utils {

std::vector<float> FillXnnRoPEWeights(size_t max_seq_len, size_t num_channels);

// expect_size_bytes == 0 means don't check the size.
template <typename element_type = char>
static absl::StatusOr<std::shared_ptr<element_type>> LoadBufferFromFile(
    absl::string_view file_path, bool use_mmap = true,
    size_t expect_size_bytes = 0) {
  if (use_mmap) {
    int fd = open(file_path.data(), O_RDONLY);
    RET_CHECK_GE(fd, 0) << "open " << file_path << " failed";
    auto cleanup = absl::MakeCleanup([fd] { close(fd); });

    const size_t size = lseek(fd, 0, SEEK_END);
    if (expect_size_bytes) {
      RET_CHECK_EQ(expect_size_bytes, size)
          << "File size " << size << ", expected " << expect_size_bytes
          << ", file path " << file_path;
    }

    void* data = mmap(/*addr=*/nullptr, size, /*prot=*/PROT_READ,
                      /*flags=*/MAP_SHARED, fd, /*offset=*/0);
    RET_CHECK_NE(data, MAP_FAILED);
    RET_CHECK_NE(data, nullptr);

    // The no-op deleter intentionally leaves the mapping alive (no munmap):
    // weights stay mapped for the lifetime of the process.
    return std::shared_ptr<element_type>(static_cast<element_type*>(data),
                                         [](auto* p) {});
  } else {
    auto read_buffer = std::make_shared<std::string>();
    MP_RETURN_IF_ERROR(
        file::GetContents(file_path, read_buffer.get(), file::Defaults()));

    if (expect_size_bytes) {
      RET_CHECK_EQ(expect_size_bytes, read_buffer->size())
          << "File size " << read_buffer->size() << ", expected "
          << expect_size_bytes << ", file path " << file_path;
    }

    return std::shared_ptr<element_type>(
        read_buffer, reinterpret_cast<element_type*>(read_buffer->data()));
  }
}
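
// Usage sketch, not part of the original header; the path and channel count
// are made up. Mmap a float buffer and verify its byte size up front.
inline absl::StatusOr<std::shared_ptr<float>> LoadScalesSketch() {
  constexpr size_t kNumChannels = 1536;
  return LoadBufferFromFile<float>(
      "/tmp/weights/final_ln.scale", /*use_mmap=*/true,
      /*expect_size_bytes=*/kNumChannels * sizeof(float));
}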

}  // namespace xnn_utils
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_UTILS_H_
@@ -1,358 +0,0 @@
#include "mediapipe/tasks/cc/text/utils/xnn_utils/xnn_tensor.h"
 | 
			
		||||
 | 
			
		||||
#include <fcntl.h>
 | 
			
		||||
#include <sys/mman.h>
 | 
			
		||||
#include <unistd.h>
 | 
			
		||||
 | 
			
		||||
#include <cstddef>
 | 
			
		||||
#include <cstring>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <ostream>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/log/check.h"
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "file/base/helpers.h"
 | 
			
		||||
#include "file/base/options.h"
 | 
			
		||||
#include "mediapipe/framework/port/ret_check.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_macros.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
 | 
			
		||||
#include "third_party/XNNPACK/include/xnnpack.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace xnn_utils {
 | 
			
		||||
 | 
			
		||||
absl::Status FillXnnRoPEWeights(Tensor& out_seg_pos) {
 | 
			
		||||
  RET_CHECK_EQ(out_seg_pos.dims.size(), 2);
 | 
			
		||||
  const size_t max_seq_len = out_seg_pos.dims[0];
 | 
			
		||||
  const size_t num_channels = out_seg_pos.dims[1];
 | 
			
		||||
  return out_seg_pos.LoadFromVec(FillXnnRoPEWeights(max_seq_len, num_channels));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::ostream& operator<<(std::ostream& os, const Tensor& tensor) {
 | 
			
		||||
  os << "Tensor{dims=[" << tensor.dims << "], datatype=" << tensor.datatype
 | 
			
		||||
     << ", num_elements=" << tensor.num_elements << "}";
 | 
			
		||||
  return os;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::ostream& operator<<(std::ostream& os, const QCTensor& tensor) {
 | 
			
		||||
  os << "QCTensor{dims=[" << tensor.dims << "], dim_scale=" << tensor.dim_scale
 | 
			
		||||
     << " datatype=" << tensor.datatype
 | 
			
		||||
     << ", num_elements=" << tensor.num_elements << "}";
 | 
			
		||||
  return os;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Tensor::operator==(const Tensor& other) const {
 | 
			
		||||
  if (dims.size() != other.dims.size()) {
 | 
			
		||||
    return false;
 | 
			
		||||
  } else if (datatype != other.datatype) {
 | 
			
		||||
    return false;
 | 
			
		||||
  } else {
 | 
			
		||||
    for (size_t i = 0; i < dims.size(); ++i) {
 | 
			
		||||
      if (dims[i] != other.dims[i]) {
 | 
			
		||||
        return false;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return 0 == memcmp(Data(), other.Data(), num_elements * ElementSize());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Tensor::AllocateBufferIfNeeded() {
 | 
			
		||||
  if (!flat_data) {
 | 
			
		||||
    auto real_buffer = std::make_shared<std::string>();
 | 
			
		||||
    real_buffer->reserve(num_elements * ElementSize() + XNN_EXTRA_BYTES);
 | 
			
		||||
    flat_data = std::shared_ptr<char>(real_buffer, real_buffer->data());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void* Tensor::Data() {
 | 
			
		||||
  DCHECK(flat_data)
 | 
			
		||||
      << "If this is weight, you may need to call one of the LoadFrom*()";
 | 
			
		||||
  return flat_data.get();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::shared_ptr<Tensor> Tensor::Slice(DimsType offset) {
 | 
			
		||||
  DCHECK(flat_data);
 | 
			
		||||
  CHECK_EQ(offset.size(), dims.size()) << offset << " vs. " << dims;
 | 
			
		||||
  // offset: [0, k, 0, 0], dims: [1, K, _, _]. dims before k must be 1.
 | 
			
		||||
  bool found_non_zero_offset = false;
 | 
			
		||||
  int index_k = -1;
 | 
			
		||||
  for (int i = 0; i < dims.size(); ++i) {
 | 
			
		||||
    if (found_non_zero_offset) {
 | 
			
		||||
      DCHECK_EQ(offset[i], 0);
 | 
			
		||||
    } else if (offset[i] != 0) {
 | 
			
		||||
      found_non_zero_offset = true;
 | 
			
		||||
      index_k = i;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  DCHECK(found_non_zero_offset) << offset;
 | 
			
		||||
 | 
			
		||||
  return Slice(index_k, offset[index_k]);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::shared_ptr<Tensor> Tensor::Slice(size_t index, size_t offset) {
 | 
			
		||||
  size_t num_elements_offset = 1;
 | 
			
		||||
  DimsType new_dim = dims;
 | 
			
		||||
  for (int i = 0; i < dims.size(); ++i) {
 | 
			
		||||
    if (i < index) {
 | 
			
		||||
      DCHECK_EQ(dims[i], 1);
 | 
			
		||||
    } else if (i == index) {
 | 
			
		||||
      num_elements_offset *= offset;
 | 
			
		||||
      new_dim[i] = 1;
 | 
			
		||||
    } else {
 | 
			
		||||
      num_elements_offset *= dims[i];
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto result = std::make_shared<Tensor>(std::move(new_dim), datatype);
 | 
			
		||||
  result->flat_data = std::shared_ptr<char>(
 | 
			
		||||
      flat_data, flat_data.get() + num_elements_offset * ElementSize());
 | 
			
		||||
  return result;
 | 
			
		||||
}
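
// Worked example (illustrative, not in the original file): for dims {1, K, D},
// Slice(/*index=*/1, /*offset=*/k) returns a {1, 1, D} tensor whose data
// aliases flat_data at byte offset k * D * ElementSize().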

Tensor& Tensor::Borrow(std::shared_ptr<Tensor> other, size_t element_offset) {
  DCHECK_EQ(datatype, other->datatype);
  DCHECK_EQ(dims.size(), other->dims.size());
  flat_data = std::shared_ptr<char>(
      other->flat_data,
      other->flat_data.get() + element_offset * ElementSize());
  return *this;
}

std::shared_ptr<Tensor> Tensor::View() { return View(dims); }

std::shared_ptr<Tensor> Tensor::View(DimsType as_dims, size_t) {
  auto result = std::make_shared<Tensor>(as_dims, datatype);
  DCHECK_LE(result->num_elements, num_elements);
  result->flat_data = flat_data;
  return result;
}

const void* Tensor::Data() const { return const_cast<Tensor*>(this)->Data(); }

absl::Status Tensor::DefineAsExternal(xnn_subgraph& subgraph, uint32_t flags) {
  uint32_t id;
  RET_CHECK_EQ(xnn_status_success,
               xnn_define_tensor_value(&subgraph, datatype, dims.size(),
                                       dims.data(), /*data=*/nullptr,
                                       /*external_id=*/tensor_id, flags, &id));
  if (tensor_id == XNN_INVALID_VALUE_ID) {
    RET_CHECK_NE(id, XNN_INVALID_VALUE_ID);
    tensor_id = id;
  } else {
    RET_CHECK_EQ(id, tensor_id);
  }
  return absl::OkStatus();
}

absl::Status Tensor::DefineAsInput(xnn_subgraph& subgraph) {
  return DefineAsExternal(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT);
}

absl::Status Tensor::DefineAsOutput(xnn_subgraph& subgraph) {
  return DefineAsExternal(subgraph, XNN_VALUE_FLAG_EXTERNAL_OUTPUT);
}

absl::Status Tensor::DefineAsIntermediateTensor(xnn_subgraph& subgraph) {
  RET_CHECK_EQ(tensor_id, XNN_INVALID_VALUE_ID);
  return DefineAsExternal(subgraph, 0);
}
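
// Usage sketch, not part of the original file; the dims and external ids are
// made up. External ids must be reserved when the subgraph is created
// (the external_value_ids argument of xnn_create_subgraph).
inline absl::Status DefineIoTensorsSketch(xnn_subgraph& subgraph) {
  Tensor input(Tensor::DimsType{1, 16, 1536});
  input.tensor_id = 0;  // external id 0
  MP_RETURN_IF_ERROR(input.DefineAsInput(subgraph));

  Tensor output(Tensor::DimsType{1, 16, 1536});
  output.tensor_id = 1;  // external id 1
  return output.DefineAsOutput(subgraph);
}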

absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) {
  RET_CHECK_EQ(
      xnn_status_success,
      xnn_define_tensor_value(&subgraph, datatype, dims.size(), dims.data(),
                              Data(), tensor_id, flags, &tensor_id));
  RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID);
  return absl::OkStatus();
}

absl::Status Tensor::DefineWeight(xnn_subgraph& subgraph) {
  RET_CHECK_EQ(tensor_id, XNN_INVALID_VALUE_ID);
  return DefineWeight(subgraph, 0);
}

absl::Status Tensor::DefineRope(xnn_subgraph& subgraph) {
  RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID);
  return DefineWeight(subgraph, XNN_VALUE_FLAG_EXTERNAL_INPUT);
}

absl::Status Tensor::LoadFromBuffer(const void* buffer) {
  AllocateBufferIfNeeded();
  memcpy(Data(), buffer, num_elements * ElementSize());
  return absl::OkStatus();
}

absl::Status Tensor::LoadFromVec(const std::vector<float>& data,
                                 bool exact_match) {
  AllocateBufferIfNeeded();
  if (exact_match) {
    RET_CHECK_EQ(num_elements * ElementSize(), data.size() * sizeof(float));
  }

  memcpy(Data(), data.data(), data.size() * sizeof(float));

  return absl::OkStatus();
}

absl::Status Tensor::LoadFromVec(std::vector<float>&& data, bool exact_match) {
  if (exact_match) {
    RET_CHECK_EQ(num_elements * ElementSize(), data.size() * sizeof(float));
  }

  auto real_buffer = std::make_shared<std::vector<float>>(std::move(data));
  if (real_buffer->size() < num_elements) {
    real_buffer->resize(num_elements);
  }
  flat_data = std::shared_ptr<char>(
      real_buffer, reinterpret_cast<char*>(real_buffer->data()));

  return absl::OkStatus();
}

absl::Status Tensor::DumpToBuffer(void* buffer) {
  memcpy(buffer, Data(), num_elements * ElementSize());
  return absl::OkStatus();
}

absl::Status Tensor::DumpToVec(std::vector<float>& out_data, bool exact_match) {
  if (exact_match) {
    RET_CHECK_EQ(num_elements * ElementSize(), out_data.size() * sizeof(float));
  } else {
    out_data.resize(num_elements);
  }
  memcpy(out_data.data(), Data(), num_elements * ElementSize());
  return absl::OkStatus();
}

absl::Status Tensor::DumpToFile(absl::string_view file_path) {
  return file::SetContents(
      file_path,
      absl::string_view(flat_data.get(), num_elements * ElementSize()),
      file::Defaults());
}

absl::Status Tensor::LoadFromFile(absl::string_view file_path, bool use_mmap,
                                  bool exact_match) {
  const size_t expected_size_in_bytes =
      exact_match ? num_elements * ElementSize() : 0;

  ASSIGN_OR_RETURN(flat_data, LoadBufferFromFile(file_path, use_mmap,
                                                 expected_size_in_bytes));
  return absl::OkStatus();
}

std::shared_ptr<Tensor> Tensor::Transpose() {
  DCHECK_EQ(dims.size(), 2);
  DimsType out_dims{dims.rbegin(), dims.rend()};
  auto result = std::make_shared<Tensor>(std::move(out_dims), datatype);
  result->AllocateBufferIfNeeded();
  xnn_status s;
  const DimsType perm{1, 0};
  if (datatype == xnn_datatype_fp32) {
    s = xnn_run_transpose_nd_x32(Data(), result->Data(), dims.size(),
                                 dims.data(), perm.data(),
                                 /*flags=*/0, /*threadpool=*/nullptr);
  } else {
    LOG(FATAL) << "Needs update to support new datatype";
  }
  DCHECK_EQ(s, xnn_status_success);
  return (s == xnn_status_success) ? result : nullptr;
}

absl::StatusOr<std::shared_ptr<Tensor>> Tensor::ConvertToF32() {
  auto result = std::make_shared<Tensor>(dims, xnn_datatype_fp32);
  MP_RETURN_IF_ERROR(result->LoadFromBuffer(Data()));
  return result;
}

absl::Status QCTensor::LoadFromFile(absl::string_view quantized_weight_filename,
                                    absl::string_view scale_filename,
                                    bool use_mmap, bool exact_match) {
  size_t scale_element_size = dims[dim_scale];

  ASSIGN_OR_RETURN(flat_data,
                   LoadBufferFromFile(quantized_weight_filename, use_mmap,
                                      exact_match ? num_elements : 0));
  ASSIGN_OR_RETURN(scale_data,
                   LoadBufferFromFile<float>(
                       scale_filename, use_mmap,
                       exact_match ? scale_element_size * sizeof(float) : 0));
  return absl::OkStatus();
}

absl::Status QCTensor::DumpToFile(absl::string_view file_path) {
  MP_RETURN_IF_ERROR(file::SetContents(
      file_path,
      absl::string_view(flat_data.get(), num_elements * ElementSize()),
      file::Defaults()));
  return file::SetContents(
      absl::StrCat(file_path, kQuantizedScaleSuffix),
      absl::string_view(reinterpret_cast<char*>(scale_data.get()),
                        dims[dim_scale] * sizeof(float)),
      file::Defaults());
}

absl::Status QCTensor::DefineWeight(xnn_subgraph& subgraph, uint32_t flags) {
  RET_CHECK_EQ(
      xnn_status_success,
      xnn_define_channelwise_quantized_tensor_value(
          &subgraph, datatype, scale_data.get(), dims.size(), dim_scale,
          dims.data(), Data(), XNN_INVALID_VALUE_ID, flags, &tensor_id))
      << *this;
  RET_CHECK_NE(tensor_id, XNN_INVALID_VALUE_ID);
  return absl::OkStatus();
}

void QCTensor::AllocateBufferIfNeeded() {
  Tensor::AllocateBufferIfNeeded();
  if (!scale_data) {
    auto real_buffer = std::make_shared<std::vector<float>>();
    // Size (not merely reserve) the scale buffer so writes through data()
    // are well-defined.
    real_buffer->resize(dims[dim_scale]);
    scale_data = std::shared_ptr<float>(real_buffer, real_buffer->data());
  }
}

std::shared_ptr<Tensor> QCTensor::Transpose() {
  DCHECK_EQ(dims.size(), 2);
  size_t channel_size = dims[dim_scale];
  DimsType out_dims{dims.rbegin(), dims.rend()};
  auto result = std::make_shared<QCTensor>(std::move(out_dims), 1 - dim_scale);
  result->AllocateBufferIfNeeded();
  memcpy(result->scale_data.get(), scale_data.get(),
         channel_size * sizeof(float));
  xnn_status s;
  const DimsType perm{1, 0};
  if (datatype == xnn_datatype_qcint8) {
    s = xnn_run_transpose_nd_x8(Data(), result->Data(), dims.size(),
                                dims.data(), perm.data(),
                                /*flags=*/0, /*threadpool=*/nullptr);
  } else {
    LOG(FATAL) << "Needs update to support new datatype";
  }
  DCHECK_EQ(s, xnn_status_success);
  return (s == xnn_status_success) ? result : nullptr;
}

absl::StatusOr<std::shared_ptr<Tensor>> QCTensor::ConvertToF32() {
  auto result = std::make_shared<Tensor>(dims, xnn_datatype_fp32);
  // TODO: implement properly. This fake implementation only allocates a
  // zero-filled f32 buffer of the right shape.
  LOG(WARNING) << "This is a fake implementation";
  MP_RETURN_IF_ERROR(result->LoadFromVec({}, /*exact_match=*/false));
  return result;
}

std::shared_ptr<Tensor> QCTensor::View(DimsType as_dims,
                                       size_t dim_scale_if_any) {
  auto result = std::make_shared<QCTensor>(as_dims, dim_scale_if_any);
  DCHECK_LE(result->num_elements, num_elements);
  result->flat_data = flat_data;
  result->scale_data = scale_data;
  return result;
}

}  // namespace xnn_utils
}  // namespace mediapipe
@@ -1,202 +0,0 @@
#ifndef MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_
#define MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_

#include <fcntl.h>
#include <sys/mman.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "file/base/helpers.h"
#include "file/base/options.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/text/utils/xnn_utils/utils.h"
#include "third_party/XNNPACK/include/xnnpack.h"
#include "util/gtl/stl_logging.h"

namespace mediapipe {
namespace xnn_utils {

static constexpr absl::string_view kQuantizedScaleSuffix{"_quantized_scale"};
static constexpr absl::string_view kSparsityParamsSuffix{"_sparsity_params"};

struct Tensor {
  using DimsType = std::vector<size_t>;

  explicit Tensor(DimsType in_dims, xnn_datatype datatype_ = xnn_datatype_fp32)
      : dims(std::move(in_dims)),
        num_elements(dims.empty() ? 0
                                  : std::accumulate(std::begin(dims),
                                                    std::end(dims), size_t(1),
                                                    std::multiplies<size_t>())),
        datatype(datatype_) {}
  Tensor(Tensor&& other) = default;

  Tensor& operator=(const Tensor& other) = delete;
  Tensor& operator=(Tensor&& other) = default;

  virtual ~Tensor() = default;

  bool operator==(const Tensor& other) const;

  void SetMetadata(absl::string_view key, int value) { metadata[key] = value; }

  std::optional<int> GetMetadata(absl::string_view key) const {
    if (metadata.contains(key)) {
      return metadata.at(key);
    }
    return std::nullopt;
  }

  // Read weights from file.
  template <xnn_datatype xnn_datatype_ = xnn_datatype_fp32>
  static absl::StatusOr<std::shared_ptr<Tensor>> FromFile(
      absl::string_view file_path, DimsType dims, bool use_mmap = true) {
    auto result = std::make_shared<Tensor>(std::move(dims), xnn_datatype_);

    MP_RETURN_IF_ERROR(
        result->LoadFromFile(file_path, use_mmap, /*exact_match=*/true));

    return result;
  }

  virtual absl::Status DefineAsExternal(xnn_subgraph& subgraph, uint32_t flags);
  absl::Status DefineAsInput(xnn_subgraph& subgraph);
  absl::Status DefineAsOutput(xnn_subgraph& subgraph);
  absl::Status DefineAsIntermediateTensor(xnn_subgraph& subgraph);
  virtual absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags);
  absl::Status DefineWeight(xnn_subgraph& subgraph);
  absl::Status DefineRope(xnn_subgraph& subgraph);

  absl::Status LoadFromBuffer(const void* buffer);
  absl::Status LoadFromVec(const std::vector<float>& data,
                           bool exact_match = true);
  absl::Status LoadFromVec(std::vector<float>&& data, bool exact_match = true);
  absl::Status LoadFromFile(absl::string_view file_path) {
    return LoadFromFile(file_path, /*use_mmap=*/true, /*exact_match=*/true);
  }
  virtual absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap,
                                    bool exact_match);

  absl::Status DumpToBuffer(void* buffer);
  absl::Status DumpToVec(std::vector<float>& out_data, bool exact_match = true);
  virtual absl::Status DumpToFile(absl::string_view file_path);

  // If the ith offset is 0, the view's ith dim equals the original ith dim;
  // otherwise it is 1.
  std::shared_ptr<Tensor> Slice(DimsType offset);
  // Slice along the `index`th dimension at the given offset.
  std::shared_ptr<Tensor> Slice(size_t index, size_t offset);

  // Point the underlying data to the borrowed tensor's data.
  Tensor& Borrow(std::shared_ptr<Tensor>, size_t element_offset = 0);
  std::shared_ptr<Tensor> View();
  virtual std::shared_ptr<Tensor> View(DimsType as_dims,
                                       size_t dim_scale_if_any = 0);

  Tensor& MarkOutput() {
    AllocateBufferIfNeeded();
    is_output_tensor = true;
    return *this;
  }

  virtual void* Data();
  const void* Data() const;

  template <typename T>
  T* DataAs() {
    DCHECK_EQ(ElementSize(), sizeof(T));
    return static_cast<T*>(Data());
  }
  template <typename T>
  const T* DataAs() const {
    return static_cast<const T*>(Data());
  }

  virtual std::shared_ptr<Tensor> Transpose();

  virtual absl::StatusOr<std::shared_ptr<Tensor>> ConvertToF32();

  DimsType dims;
  size_t num_elements = 0;
  xnn_datatype datatype = xnn_datatype_invalid;
  uint32_t tensor_id = XNN_INVALID_VALUE_ID;

  // shared_ptr to make TensorMetadata copyable.
  std::shared_ptr<char> flat_data;

 protected:
  friend class XnnGraphBuilder;
  friend class XnnGraph;

  // Allocate the buffer lazily, only when necessary.
  virtual void AllocateBufferIfNeeded();

  virtual size_t ElementSize() const { return 4; }

  bool is_output_tensor = false;

  absl::flat_hash_map<std::string, int> metadata;
};
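
// Usage sketch, not part of the original header; the path and dims are made
// up. Load an fp32 weight from disk, then read it through a typed pointer.
inline absl::Status ReadWeightSketch() {
  ASSIGN_OR_RETURN(
      std::shared_ptr<Tensor> w,
      Tensor::FromFile("/tmp/weights/final_ln.scale", Tensor::DimsType{1536}));
  const float first = w->DataAs<float>()[0];  // typed element access
  (void)first;
  return absl::OkStatus();
}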

std::ostream& operator<<(std::ostream& os, const Tensor& tensor);

// Channelwise quantized tensor.
struct QCTensor : public Tensor {
  explicit QCTensor(DimsType in_dims, size_t dim_scale_if_any)
      : Tensor(std::move(in_dims)), dim_scale(dim_scale_if_any) {
    datatype = xnn_datatype_qcint8;
    CHECK_LT(dim_scale, 4);
  }

  void AllocateBufferIfNeeded() override;
  size_t ElementSize() const override { return 1; }

  virtual absl::Status LoadFromFile(absl::string_view quantized_weight_filename,
                                    absl::string_view scale_filename,
                                    bool use_mmap, bool exact_match);
  // Appends kQuantizedScaleSuffix to derive the scale filename.
  absl::Status LoadFromFile(absl::string_view file_path, bool use_mmap,
                            bool exact_match) override {
    return LoadFromFile(file_path,
                        absl::StrCat(file_path, kQuantizedScaleSuffix),
                        use_mmap, exact_match);
  }

  absl::Status DumpToFile(absl::string_view file_path) override;

  absl::Status DefineWeight(xnn_subgraph& subgraph, uint32_t flags) override;

  std::shared_ptr<Tensor> Transpose() override;

  absl::StatusOr<std::shared_ptr<Tensor>> ConvertToF32() override;

  std::shared_ptr<Tensor> View(DimsType as_dims,
                               size_t dim_scale_if_any) override;

  std::shared_ptr<float> scale_data;
  // Index of the dimension to scale along.
  size_t dim_scale;
};
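
// Sketch, not in the original header; the path and dims are made up. A
// channelwise-quantized weight is stored as a pair of files: "<name>" holds
// the int8 data and "<name>_quantized_scale" the per-channel float scales;
// the single-filename overload above derives the second path automatically.
inline absl::StatusOr<std::shared_ptr<QCTensor>> LoadQuantizedWeightSketch() {
  auto w = std::make_shared<QCTensor>(Tensor::DimsType{1536, 32000},
                                      /*dim_scale_if_any=*/1);
  MP_RETURN_IF_ERROR(w->LoadFromFile("/tmp/weights/softmax.w",
                                     /*use_mmap=*/true,
                                     /*exact_match=*/true));
  return w;
}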

std::ostream& operator<<(std::ostream& os, const QCTensor& tensor);

absl::Status FillXnnRoPEWeights(Tensor& out_seg_pos);

}  // namespace xnn_utils
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_TEXT_UTILS_XNN_UTILS_XNN_TENSOR_H_