Add Windows onnxruntime CUDA inference code

liuyulvv 2022-08-05 18:24:12 +08:00
parent 5a548cb77e
commit 3a6def6fee
5 changed files with 209 additions and 0 deletions


@@ -226,6 +226,12 @@ new_local_repository(
path = "D:\\opencv\\build",
)
new_local_repository(
name = "windows_onnxruntime",
build_file = "@//third_party:onnxruntime_windows.BUILD",
path = "D:\\onnxruntime\\onnxruntime-win-x64-gpu-1.12.0",
)
http_archive(
name = "android_opencv",
build_file = "@//third_party:opencv_android.BUILD",


@@ -245,6 +245,38 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "inference_calculator_onnx_cuda",
srcs = [
"inference_calculator_onnx_cuda.cc",
],
copts = select({
# TODO: fix tensor.h not to require this, if possible
"//mediapipe:apple": [
"-x objective-c++",
"-fobjc-arc", # enable reference-counting
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":inference_calculator_interface",
"@com_google_absl//absl/memory",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
"@windows_onnxruntime//:onnxruntime",
] + select({
"//conditions:default": [
"//mediapipe/util:cpu_util",
],
}) + select({
"//conditions:default": [],
"//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"],
}),
alwayslink = 1,
)
cc_library(
name = "inference_calculator_gl_if_compute_shader_available",
visibility = ["//visibility:public"],


@@ -145,6 +145,10 @@ struct InferenceCalculatorCpu : public InferenceCalculator {
static constexpr char kCalculatorName[] = "InferenceCalculatorCpu";
};
struct InferenceCalculatorOnnxCUDA : public InferenceCalculator {
static constexpr char kCalculatorName[] = "InferenceCalculatorOnnxCUDA";
};
} // namespace api2
} // namespace mediapipe
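
As a usage sketch (not part of this commit): a graph could select the new backend by its registered calculator name, assuming the usual TENSORS stream tags and the InferenceCalculatorOptions extension; the stream names and model path below are placeholders.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Hypothetical helper: builds a one-node graph that runs inference through the
// new ONNX Runtime CUDA backend. "path/to/model.onnx" is a placeholder.
mediapipe::CalculatorGraphConfig BuildOnnxCudaGraphConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "input_tensors"
    output_stream: "output_tensors"
    node {
      calculator: "InferenceCalculatorOnnxCUDA"
      input_stream: "TENSORS:input_tensors"
      output_stream: "TENSORS:output_tensors"
      options {
        [mediapipe.InferenceCalculatorOptions.ext] { model_path: "path/to/model.onnx" }
      }
    }
  )pb");
}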


@@ -0,0 +1,150 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/memory/memory.h"
#include "mediapipe/calculators/tensor/inference_calculator.h"
#include "onnxruntime_cxx_api.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter_builder.h"
#include <cstring>
#include <memory>
#include <string>
#include <vector>
namespace mediapipe {
namespace api2 {
namespace {
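// Returns the number of elements implied by the given dimensions (product of dims).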
int64_t value_size_of(const std::vector<int64_t>& dims) {
if (dims.empty()) return 0;
int64_t value_size = 1;
for (const auto& size : dims) value_size *= size;
return value_size;
}
} // namespace
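// Runs ONNX model inference through ONNX Runtime with the CUDA execution
// provider. Tensors are exchanged with the graph as CPU-side MediaPipe Tensors.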
class InferenceCalculatorOnnxCUDAImpl
: public NodeImpl<InferenceCalculatorOnnxCUDA, InferenceCalculatorOnnxCUDAImpl> {
public:
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
absl::Status Close(CalculatorContext* cc) override;
private:
absl::Status LoadModel(const std::string& path);
Ort::Env env_;
std::unique_ptr<Ort::Session> session_;
Ort::AllocatorWithDefaultOptions allocator;
Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<const char*> m_input_names;
std::vector<const char*> m_output_names;
};
absl::Status InferenceCalculatorOnnxCUDAImpl::UpdateContract(
CalculatorContract* cc) {
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
<< "Either model as side packet or model path in options is required.";
return absl::OkStatus();
}
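// Creates the ONNX Runtime session with the CUDA execution provider enabled
// and caches the model's input and output tensor names.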
absl::Status InferenceCalculatorOnnxCUDAImpl::LoadModel(const std::string& path) {
auto model_path = std::wstring(path.begin(), path.end());
Ort::SessionOptions session_options;
OrtCUDAProviderOptions cuda_options;
session_options.AppendExecutionProvider_CUDA(cuda_options);
session_ = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
size_t num_input_nodes = session_->GetInputCount();
size_t num_output_nodes = session_->GetOutputCount();
m_input_names.reserve(num_input_nodes);
m_output_names.reserve(num_output_nodes);
for (int i = 0; i < num_input_nodes; i++) {
char* input_name = session_->GetInputName(i, allocator);
m_input_names.push_back(input_name);
}
for (int i = 0; i < num_output_nodes; i++) {
char* output_name = session_->GetOutputName(i, allocator);
m_output_names.push_back(output_name);
}
return absl::OkStatus();
}
absl::Status InferenceCalculatorOnnxCUDAImpl::Open(CalculatorContext* cc) {
const auto& options = cc->Options<mediapipe::InferenceCalculatorOptions>();
if (!options.model_path().empty()) {
return LoadModel(options.model_path());
}
if (!options.landmark_path().empty()) {
return LoadModel(options.landmark_path());
}
return absl::Status(mediapipe::StatusCode::kNotFound, "Must specify Onnx model path.");
}
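// Wraps the incoming CPU tensors as Ort::Values, runs the session, and copies
// each output into a newly allocated MediaPipe Tensor.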
absl::Status InferenceCalculatorOnnxCUDAImpl::Process(CalculatorContext* cc) {
if (kInTensors(cc).IsEmpty()) {
return absl::OkStatus();
}
const auto& input_tensors = *kInTensors(cc);
RET_CHECK(!input_tensors.empty());
auto input_tensor_type = int(input_tensors[0].element_type());
std::vector<Ort::Value> ort_input_tensors;
ort_input_tensors.reserve(input_tensors.size());
for (const auto& tensor : input_tensors) {
auto& inputDims = tensor.shape().dims;
std::vector<int64_t> src_dims{inputDims[0], inputDims[1], inputDims[2], inputDims[3]};
auto src_value_size = value_size_of(src_dims);
auto input_tensor_view = tensor.GetCpuReadView();
auto input_tensor_buffer = const_cast<float*>(input_tensor_view.buffer<float>());
auto tmp_tensor = Ort::Value::CreateTensor<float>(memory_info_handler, input_tensor_buffer, src_value_size, src_dims.data(), src_dims.size());
ort_input_tensors.emplace_back(std::move(tmp_tensor));
}
auto output_tensors = absl::make_unique<std::vector<Tensor>>();
std::vector<Ort::Value> onnx_output_tensors;
try {
onnx_output_tensors = session_->Run(
Ort::RunOptions{nullptr}, m_input_names.data(),
ort_input_tensors.data(), ort_input_tensors.size(), m_output_names.data(),
m_output_names.size());
} catch (const Ort::Exception& e) {
LOG(ERROR) << "Run error msg: " << e.what();
return absl::Status(mediapipe::StatusCode::kInternal, e.what());
}
for (const auto& tensor : onnx_output_tensors) {
auto info = tensor.GetTensorTypeAndShapeInfo();
auto dims = info.GetShape();
std::vector<int> tmp_dims;
for (const auto& i : dims) {
tmp_dims.push_back(i);
}
output_tensors->emplace_back(Tensor::ElementType::kFloat32, Tensor::Shape{tmp_dims});
auto cpu_view = output_tensors->back().GetCpuWriteView();
std::memcpy(cpu_view.buffer<float>(), tensor.GetTensorData<float>(), output_tensors->back().bytes());
}
kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus();
}
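// Releases the ONNX Runtime session.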
absl::Status InferenceCalculatorOnnxCUDAImpl::Close(CalculatorContext* cc) {
session_ = nullptr;
return absl::OkStatus();
}
} // namespace api2
} // namespace mediapipe

third_party/onnxruntime_windows.BUILD (new file)

@@ -0,0 +1,17 @@
cc_library(
name = "onnxruntime",
srcs = [
"lib/onnxruntime.dll",
"lib/onnxruntime.lib",
"lib/onnxruntime_providers_cuda.dll",
"lib/onnxruntime_providers_cuda.lib",
"lib/onnxruntime_providers_shared.dll",
"lib/onnxruntime_providers_shared.lib",
"lib/onnxruntime_providers_tensorrt.dll",
"lib/onnxruntime_providers_tensorrt.lib",
],
hdrs = glob(["include/*.h*"]),
includes = ["include/"],
linkstatic = 1,
visibility = ["//visibility:public"],
)