Project import generated by Copybara.
GitOrigin-RevId: 3ce19771d2586aeb611fff75bb7627721cf5d36b
parent 4dc4b19ddb
commit d3f98334bf
@@ -51,6 +51,8 @@ RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100 --slave /u
 RUN pip3 install --upgrade setuptools
 RUN pip3 install wheel
 RUN pip3 install future
+RUN pip3 install absl-py
+RUN pip3 install numpy
 RUN pip3 install six==1.14.0
 RUN pip3 install tensorflow==2.2.0
 RUN pip3 install tf_slim
@@ -216,6 +216,50 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "inference_runner",
+    hdrs = ["inference_runner.h"],
+    copts = select({
+        # TODO: fix tensor.h not to require this, if possible
+        "//mediapipe:apple": [
+            "-x objective-c++",
+            "-fobjc-arc",  # enable reference-counting
+        ],
+        "//conditions:default": [],
+    }),
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework/formats:tensor",
+        "@com_google_absl//absl/status:statusor",
+    ],
+)
+
+cc_library(
+    name = "inference_interpreter_delegate_runner",
+    srcs = ["inference_interpreter_delegate_runner.cc"],
+    hdrs = ["inference_interpreter_delegate_runner.h"],
+    copts = select({
+        # TODO: fix tensor.h not to require this, if possible
+        "//mediapipe:apple": [
+            "-x objective-c++",
+            "-fobjc-arc",  # enable reference-counting
+        ],
+        "//conditions:default": [],
+    }),
+    visibility = ["//visibility:public"],
+    deps = [
+        ":inference_runner",
+        "//mediapipe/framework/api2:packet",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/framework/port:ret_check",
+        "//mediapipe/util/tflite:tflite_model_loader",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@org_tensorflow//tensorflow/lite:framework_stable",
+        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
+    ],
+)
+
 cc_library(
     name = "inference_calculator_cpu",
     srcs = [

@@ -232,22 +276,63 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":inference_calculator_interface",
+        ":inference_calculator_utils",
+        ":inference_interpreter_delegate_runner",
+        ":inference_runner",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
        "@org_tensorflow//tensorflow/lite:framework_stable",
         "@org_tensorflow//tensorflow/lite/c:c_api_types",
     ] + select({
-        "//conditions:default": [
-            "//mediapipe/util:cpu_util",
-        ],
-    }) + select({
         "//conditions:default": [],
         "//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"],
     }),
     alwayslink = 1,
 )
 
+cc_library(
+    name = "inference_calculator_utils",
+    srcs = ["inference_calculator_utils.cc"],
+    hdrs = ["inference_calculator_utils.h"],
+    deps = [
+        ":inference_calculator_cc_proto",
+    ] + select({
+        "//conditions:default": [
+            "//mediapipe/util:cpu_util",
+        ],
+    }),
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "inference_calculator_xnnpack",
+    srcs = [
+        "inference_calculator_xnnpack.cc",
+    ],
+    copts = select({
+        # TODO: fix tensor.h not to require this, if possible
+        "//mediapipe:apple": [
+            "-x objective-c++",
+            "-fobjc-arc",  # enable reference-counting
+        ],
+        "//conditions:default": [],
+    }),
+    visibility = ["//visibility:public"],
+    deps = [
+        ":inference_calculator_interface",
+        ":inference_calculator_utils",
+        ":inference_interpreter_delegate_runner",
+        ":inference_runner",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@org_tensorflow//tensorflow/lite:framework_stable",
+        "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "inference_calculator_gl_if_compute_shader_available",
     visibility = ["//visibility:public"],
@@ -59,6 +59,7 @@ class InferenceCalculatorSelectorImpl
     }
   }
   impls.emplace_back("Cpu");
+  impls.emplace_back("Xnnpack");
   for (const auto& suffix : impls) {
     const auto impl = absl::StrCat("InferenceCalculator", suffix);
     if (!mediapipe::CalculatorBaseRegistry::IsRegistered(impl)) continue;
@@ -141,6 +141,10 @@ struct InferenceCalculatorCpu : public InferenceCalculator {
   static constexpr char kCalculatorName[] = "InferenceCalculatorCpu";
 };
 
+struct InferenceCalculatorXnnpack : public InferenceCalculator {
+  static constexpr char kCalculatorName[] = "InferenceCalculatorXnnpack";
+};
+
 }  // namespace api2
 }  // namespace mediapipe
 
@@ -18,78 +18,21 @@
 #include <string>
 #include <vector>
 
-#include "absl/memory/memory.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "mediapipe/calculators/tensor/inference_calculator.h"
+#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
+#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
+#include "mediapipe/calculators/tensor/inference_runner.h"
 #include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/interpreter_builder.h"
 #if defined(MEDIAPIPE_ANDROID)
 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
 #endif  // ANDROID
 
-#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
-#include "mediapipe/util/cpu_util.h"
-#endif  // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__
-
-#include "tensorflow/lite/c/c_api_types.h"
 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
 
 namespace mediapipe {
 namespace api2 {
 
-namespace {
-
-int GetXnnpackDefaultNumThreads() {
-#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_IOS) || \
-    defined(__EMSCRIPTEN_PTHREADS__)
-  constexpr int kMinNumThreadsByDefault = 1;
-  constexpr int kMaxNumThreadsByDefault = 4;
-  return std::clamp(NumCPUCores() / 2, kMinNumThreadsByDefault,
-                    kMaxNumThreadsByDefault);
-#else
-  return 1;
-#endif  // MEDIAPIPE_ANDROID || MEDIAPIPE_IOS || __EMSCRIPTEN_PTHREADS__
-}
-
-// Returns number of threads to configure XNNPACK delegate with.
-// Returns user provided value if specified. Otherwise, tries to choose optimal
-// number of threads depending on the device.
-int GetXnnpackNumThreads(
-    const bool opts_has_delegate,
-    const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) {
-  static constexpr int kDefaultNumThreads = -1;
-  if (opts_has_delegate && opts_delegate.has_xnnpack() &&
-      opts_delegate.xnnpack().num_threads() != kDefaultNumThreads) {
-    return opts_delegate.xnnpack().num_threads();
-  }
-  return GetXnnpackDefaultNumThreads();
-}
-
-template <typename T>
-void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
-                                   tflite::Interpreter* interpreter,
-                                   int input_tensor_index) {
-  auto input_tensor_view = input_tensor.GetCpuReadView();
-  auto input_tensor_buffer = input_tensor_view.buffer<T>();
-  T* local_tensor_buffer =
-      interpreter->typed_input_tensor<T>(input_tensor_index);
-  std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
-}
-
-template <typename T>
-void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
-                                     int output_tensor_index,
-                                     Tensor* output_tensor) {
-  auto output_tensor_view = output_tensor->GetCpuWriteView();
-  auto output_tensor_buffer = output_tensor_view.buffer<T>();
-  T* local_tensor_buffer =
-      interpreter->typed_output_tensor<T>(output_tensor_index);
-  std::memcpy(output_tensor_buffer, local_tensor_buffer,
-              output_tensor->bytes());
-}
-
-} // namespace
-
 class InferenceCalculatorCpuImpl
     : public NodeImpl<InferenceCalculatorCpu, InferenceCalculatorCpuImpl> {
  public:
@@ -100,16 +43,11 @@ class InferenceCalculatorCpuImpl
   absl::Status Close(CalculatorContext* cc) override;
 
  private:
-  absl::Status InitInterpreter(CalculatorContext* cc);
-  absl::Status LoadDelegate(CalculatorContext* cc,
-                            tflite::InterpreterBuilder* interpreter_builder);
-  absl::Status AllocateTensors();
+  absl::StatusOr<std::unique_ptr<InferenceRunner>> CreateInferenceRunner(
+      CalculatorContext* cc);
+  absl::StatusOr<TfLiteDelegatePtr> MaybeCreateDelegate(CalculatorContext* cc);
 
-  // TfLite requires us to keep the model alive as long as the interpreter is.
-  Packet<TfLiteModelPtr> model_packet_;
-  std::unique_ptr<tflite::Interpreter> interpreter_;
-  TfLiteDelegatePtr delegate_;
-  TfLiteType input_tensor_type_ = TfLiteType::kTfLiteNoType;
+  std::unique_ptr<InferenceRunner> inference_runner_;
 };
 
 absl::Status InferenceCalculatorCpuImpl::UpdateContract(
@@ -122,7 +60,8 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract(
 }
 
 absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) {
-  return InitInterpreter(cc);
+  ASSIGN_OR_RETURN(inference_runner_, CreateInferenceRunner(cc));
+  return absl::OkStatus();
 }
 
 absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
@@ -131,123 +70,32 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
   }
   const auto& input_tensors = *kInTensors(cc);
   RET_CHECK(!input_tensors.empty());
-  auto output_tensors = absl::make_unique<std::vector<Tensor>>();
 
-  if (input_tensor_type_ == kTfLiteNoType) {
-    input_tensor_type_ = interpreter_->tensor(interpreter_->inputs()[0])->type;
-  }
-
-  // Read CPU input into tensors.
-  for (int i = 0; i < input_tensors.size(); ++i) {
-    switch (input_tensor_type_) {
-      case TfLiteType::kTfLiteFloat16:
-      case TfLiteType::kTfLiteFloat32: {
-        CopyTensorBufferToInterpreter<float>(input_tensors[i],
-                                             interpreter_.get(), i);
-        break;
-      }
-      case TfLiteType::kTfLiteUInt8: {
-        CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
-                                             interpreter_.get(), i);
-        break;
-      }
-      case TfLiteType::kTfLiteInt8: {
-        CopyTensorBufferToInterpreter<int8>(input_tensors[i],
-                                            interpreter_.get(), i);
-        break;
-      }
-      case TfLiteType::kTfLiteInt32: {
-        CopyTensorBufferToInterpreter<int32_t>(input_tensors[i],
-                                               interpreter_.get(), i);
-        break;
-      }
-      default:
-        return absl::InvalidArgumentError(
-            absl::StrCat("Unsupported input tensor type:", input_tensor_type_));
-    }
-  }
-
-  // Run inference.
-  RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
-
-  // Output result tensors (CPU).
-  const auto& tensor_indexes = interpreter_->outputs();
-  output_tensors->reserve(tensor_indexes.size());
-  for (int i = 0; i < tensor_indexes.size(); ++i) {
-    TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
-    Tensor::Shape shape{std::vector<int>{
-        tensor->dims->data, tensor->dims->data + tensor->dims->size}};
-    switch (tensor->type) {
-      case TfLiteType::kTfLiteFloat16:
-      case TfLiteType::kTfLiteFloat32:
-        output_tensors->emplace_back(Tensor::ElementType::kFloat32, shape);
-        CopyTensorBufferFromInterpreter<float>(interpreter_.get(), i,
-                                               &output_tensors->back());
-        break;
-      case TfLiteType::kTfLiteUInt8:
-        output_tensors->emplace_back(
-            Tensor::ElementType::kUInt8, shape,
-            Tensor::QuantizationParameters{tensor->params.scale,
-                                           tensor->params.zero_point});
-        CopyTensorBufferFromInterpreter<uint8>(interpreter_.get(), i,
-                                               &output_tensors->back());
-        break;
-      case TfLiteType::kTfLiteInt8:
-        output_tensors->emplace_back(
-            Tensor::ElementType::kInt8, shape,
-            Tensor::QuantizationParameters{tensor->params.scale,
-                                           tensor->params.zero_point});
-        CopyTensorBufferFromInterpreter<int8>(interpreter_.get(), i,
-                                              &output_tensors->back());
-        break;
-      case TfLiteType::kTfLiteInt32:
-        output_tensors->emplace_back(Tensor::ElementType::kInt32, shape);
-        CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
-                                                 &output_tensors->back());
-        break;
-      default:
-        return absl::InvalidArgumentError(
-            absl::StrCat("Unsupported output tensor type:",
-                         TfLiteTypeGetName(tensor->type)));
-    }
-  }
+  ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
+                   inference_runner_->Run(input_tensors));
   kOutTensors(cc).Send(std::move(output_tensors));
   return absl::OkStatus();
 }
 
 absl::Status InferenceCalculatorCpuImpl::Close(CalculatorContext* cc) {
-  interpreter_ = nullptr;
-  delegate_ = nullptr;
+  inference_runner_ = nullptr;
   return absl::OkStatus();
 }
 
-absl::Status InferenceCalculatorCpuImpl::InitInterpreter(
-    CalculatorContext* cc) {
-  ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
-  const auto& model = *model_packet_.Get();
+absl::StatusOr<std::unique_ptr<InferenceRunner>>
+InferenceCalculatorCpuImpl::CreateInferenceRunner(CalculatorContext* cc) {
+  ASSIGN_OR_RETURN(auto model_packet, GetModelAsPacket(cc));
   ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
-  const auto& op_resolver = op_resolver_packet.Get();
-  tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
-  MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
-#if defined(__EMSCRIPTEN__)
-  interpreter_builder.SetNumThreads(1);
-#else
-  interpreter_builder.SetNumThreads(
-      cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
-#endif  // __EMSCRIPTEN__
-
-  RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
-  RET_CHECK(interpreter_);
-  return AllocateTensors();
+  const int interpreter_num_threads =
+      cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread();
+  ASSIGN_OR_RETURN(TfLiteDelegatePtr delegate, MaybeCreateDelegate(cc));
+  return CreateInferenceInterpreterDelegateRunner(
+      std::move(model_packet), std::move(op_resolver_packet),
+      std::move(delegate), interpreter_num_threads);
 }
 
-absl::Status InferenceCalculatorCpuImpl::AllocateTensors() {
-  RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
-  return absl::OkStatus();
-}
-
-absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
-    CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
+absl::StatusOr<TfLiteDelegatePtr>
+InferenceCalculatorCpuImpl::MaybeCreateDelegate(CalculatorContext* cc) {
   const auto& calculator_opts =
       cc->Options<mediapipe::InferenceCalculatorOptions>();
   auto opts_delegate = calculator_opts.delegate();
@@ -268,7 +116,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
       calculator_opts.has_delegate() || !kDelegate(cc).IsEmpty();
   if (opts_has_delegate && opts_delegate.has_tflite()) {
     // Default tflite inference requeqsted - no need to modify graph.
-    return absl::OkStatus();
+    return nullptr;
   }
 
 #if defined(MEDIAPIPE_ANDROID)
@@ -288,10 +136,8 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
     options.accelerator_name = nnapi.has_accelerator_name()
                                    ? nnapi.accelerator_name().c_str()
                                    : nullptr;
-    delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
-                                  [](TfLiteDelegate*) {});
-    interpreter_builder->AddDelegate(delegate_.get());
-    return absl::OkStatus();
+    return TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
+                             [](TfLiteDelegate*) {});
   }
 #endif  // MEDIAPIPE_ANDROID
 
@@ -305,12 +151,11 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
     auto xnnpack_opts = TfLiteXNNPackDelegateOptionsDefault();
     xnnpack_opts.num_threads =
         GetXnnpackNumThreads(opts_has_delegate, opts_delegate);
-    delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
-                                  &TfLiteXNNPackDelegateDelete);
-    interpreter_builder->AddDelegate(delegate_.get());
+    return TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
+                             &TfLiteXNNPackDelegateDelete);
   }
 
-  return absl::OkStatus();
+  return nullptr;
 }
 
 }  // namespace api2
mediapipe/calculators/tensor/inference_calculator_utils.cc (new file, 52 lines)
@@ -0,0 +1,52 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
+
+#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
+
+#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
+#include "mediapipe/util/cpu_util.h"
+#endif  // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__
+
+namespace mediapipe {
+
+namespace {
+
+int GetXnnpackDefaultNumThreads() {
+#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_IOS) || \
+    defined(__EMSCRIPTEN_PTHREADS__)
+  constexpr int kMinNumThreadsByDefault = 1;
+  constexpr int kMaxNumThreadsByDefault = 4;
+  return std::clamp(NumCPUCores() / 2, kMinNumThreadsByDefault,
+                    kMaxNumThreadsByDefault);
+#else
+  return 1;
+#endif  // MEDIAPIPE_ANDROID || MEDIAPIPE_IOS || __EMSCRIPTEN_PTHREADS__
+}
+
+}  // namespace
+
+int GetXnnpackNumThreads(
+    const bool opts_has_delegate,
+    const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) {
+  static constexpr int kDefaultNumThreads = -1;
+  if (opts_has_delegate && opts_delegate.has_xnnpack() &&
+      opts_delegate.xnnpack().num_threads() != kDefaultNumThreads) {
+    return opts_delegate.xnnpack().num_threads();
+  }
+  return GetXnnpackDefaultNumThreads();
+}
+
+}  // namespace mediapipe
mediapipe/calculators/tensor/inference_calculator_utils.h (new file, 31 lines)
@@ -0,0 +1,31 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_
+#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_
+
+#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
+
+namespace mediapipe {
+
+// Returns number of threads to configure XNNPACK delegate with.
+// Returns user provided value if specified. Otherwise, tries to choose optimal
+// number of threads depending on the device.
+int GetXnnpackNumThreads(
+    const bool opts_has_delegate,
+    const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate);
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_
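
For reference, a minimal sketch of how the new helper might be called from delegate-construction code; it mirrors the call sites in inference_calculator_cpu.cc and inference_calculator_xnnpack.cc in this change, but the wrapper function below is hypothetical and not part of the patch.

// Hypothetical helper: builds XNNPACK delegate options whose thread count is
// taken from the calculator options, falling back to the device-dependent
// default picked by GetXnnpackDefaultNumThreads().
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"

TfLiteXNNPackDelegateOptions MakeXnnpackOptions(
    const mediapipe::InferenceCalculatorOptions& calculator_opts) {
  auto xnnpack_opts = TfLiteXNNPackDelegateOptionsDefault();
  xnnpack_opts.num_threads = mediapipe::GetXnnpackNumThreads(
      calculator_opts.has_delegate(), calculator_opts.delegate());
  return xnnpack_opts;
}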
mediapipe/calculators/tensor/inference_calculator_xnnpack.cc (new file, 122 lines)
@@ -0,0 +1,122 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/calculators/tensor/inference_calculator.h"
+#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
+#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
+#include "mediapipe/calculators/tensor/inference_runner.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
+#include "tensorflow/lite/interpreter.h"
+
+namespace mediapipe {
+namespace api2 {
+
+class InferenceCalculatorXnnpackImpl
+    : public NodeImpl<InferenceCalculatorXnnpack,
+                      InferenceCalculatorXnnpackImpl> {
+ public:
+  static absl::Status UpdateContract(CalculatorContract* cc);
+
+  absl::Status Open(CalculatorContext* cc) override;
+  absl::Status Process(CalculatorContext* cc) override;
+  absl::Status Close(CalculatorContext* cc) override;
+
+ private:
+  absl::StatusOr<std::unique_ptr<InferenceRunner>> CreateInferenceRunner(
+      CalculatorContext* cc);
+  absl::StatusOr<TfLiteDelegatePtr> CreateDelegate(CalculatorContext* cc);
+
+  std::unique_ptr<InferenceRunner> inference_runner_;
+};
+
+absl::Status InferenceCalculatorXnnpackImpl::UpdateContract(
+    CalculatorContract* cc) {
+  const auto& options = cc->Options<mediapipe::InferenceCalculatorOptions>();
+  RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
+      << "Either model as side packet or model path in options is required.";
+
+  return absl::OkStatus();
+}
+
+absl::Status InferenceCalculatorXnnpackImpl::Open(CalculatorContext* cc) {
+  ASSIGN_OR_RETURN(inference_runner_, CreateInferenceRunner(cc));
+  return absl::OkStatus();
+}
+
+absl::Status InferenceCalculatorXnnpackImpl::Process(CalculatorContext* cc) {
+  if (kInTensors(cc).IsEmpty()) {
+    return absl::OkStatus();
+  }
+  const auto& input_tensors = *kInTensors(cc);
+  RET_CHECK(!input_tensors.empty());
+
+  ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
+                   inference_runner_->Run(input_tensors));
+  kOutTensors(cc).Send(std::move(output_tensors));
+  return absl::OkStatus();
+}
+
+absl::Status InferenceCalculatorXnnpackImpl::Close(CalculatorContext* cc) {
+  inference_runner_ = nullptr;
+  return absl::OkStatus();
+}
+
+absl::StatusOr<std::unique_ptr<InferenceRunner>>
+InferenceCalculatorXnnpackImpl::CreateInferenceRunner(CalculatorContext* cc) {
+  ASSIGN_OR_RETURN(auto model_packet, GetModelAsPacket(cc));
+  ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
+  const int interpreter_num_threads =
+      cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread();
+  ASSIGN_OR_RETURN(TfLiteDelegatePtr delegate, CreateDelegate(cc));
+  return CreateInferenceInterpreterDelegateRunner(
+      std::move(model_packet), std::move(op_resolver_packet),
+      std::move(delegate), interpreter_num_threads);
+}
+
+absl::StatusOr<TfLiteDelegatePtr>
+InferenceCalculatorXnnpackImpl::CreateDelegate(CalculatorContext* cc) {
+  const auto& calculator_opts =
+      cc->Options<mediapipe::InferenceCalculatorOptions>();
+  auto opts_delegate = calculator_opts.delegate();
+  if (!kDelegate(cc).IsEmpty()) {
+    const mediapipe::InferenceCalculatorOptions::Delegate&
+        input_side_packet_delegate = kDelegate(cc).Get();
+    RET_CHECK(
+        input_side_packet_delegate.has_xnnpack() ||
+        input_side_packet_delegate.delegate_case() ==
+            mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET)
+        << "inference_calculator_cpu only supports delegate input side packet "
+        << "for TFLite, XNNPack";
+    opts_delegate.MergeFrom(input_side_packet_delegate);
+  }
+  const bool opts_has_delegate =
+      calculator_opts.has_delegate() || !kDelegate(cc).IsEmpty();
+
+  auto xnnpack_opts = TfLiteXNNPackDelegateOptionsDefault();
+  xnnpack_opts.num_threads =
+      GetXnnpackNumThreads(opts_has_delegate, opts_delegate);
+  return TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
+                           &TfLiteXNNPackDelegateDelete);
+}
+
+}  // namespace api2
+}  // namespace mediapipe
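
Because InferenceCalculatorXnnpackImpl registers under the name "InferenceCalculatorXnnpack", a graph can target it directly. The following is a rough usage sketch, not part of this change; it assumes the usual TENSORS stream tags and the InferenceCalculatorOptions extension, and the model path is a placeholder.

// Hypothetical graph-construction sketch for the new XNNPACK-only node.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

mediapipe::CalculatorGraphConfig BuildXnnpackInferenceGraph() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "input_tensors"
    output_stream: "output_tensors"
    node {
      calculator: "InferenceCalculatorXnnpack"
      input_stream: "TENSORS:input_tensors"
      output_stream: "TENSORS:output_tensors"
      options {
        [mediapipe.InferenceCalculatorOptions.ext] {
          model_path: "path/to/model.tflite"  # placeholder
          delegate { xnnpack { num_threads: 2 } }
        }
      }
    }
  )pb");
}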
mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc (new file)
@@ -0,0 +1,181 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
+
+#include <memory>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/port/ret_check.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/interpreter_builder.h"
+
+namespace mediapipe {
+
+namespace {
+
+template <typename T>
+void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
+                                   tflite::Interpreter* interpreter,
+                                   int input_tensor_index) {
+  auto input_tensor_view = input_tensor.GetCpuReadView();
+  auto input_tensor_buffer = input_tensor_view.buffer<T>();
+  T* local_tensor_buffer =
+      interpreter->typed_input_tensor<T>(input_tensor_index);
+  std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
+}
+
+template <typename T>
+void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
+                                     int output_tensor_index,
+                                     Tensor* output_tensor) {
+  auto output_tensor_view = output_tensor->GetCpuWriteView();
+  auto output_tensor_buffer = output_tensor_view.buffer<T>();
+  T* local_tensor_buffer =
+      interpreter->typed_output_tensor<T>(output_tensor_index);
+  std::memcpy(output_tensor_buffer, local_tensor_buffer,
+              output_tensor->bytes());
+}
+
+}  // namespace
+
+class InferenceInterpreterDelegateRunner : public InferenceRunner {
+ public:
+  InferenceInterpreterDelegateRunner(
+      api2::Packet<TfLiteModelPtr> model,
+      std::unique_ptr<tflite::Interpreter> interpreter,
+      TfLiteDelegatePtr delegate)
+      : model_(std::move(model)),
+        interpreter_(std::move(interpreter)),
+        delegate_(std::move(delegate)) {}
+
+  absl::StatusOr<std::vector<Tensor>> Run(
+      const std::vector<Tensor>& input_tensors) override;
+
+ private:
+  api2::Packet<TfLiteModelPtr> model_;
+  std::unique_ptr<tflite::Interpreter> interpreter_;
+  TfLiteDelegatePtr delegate_;
+};
+
+absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
+    const std::vector<Tensor>& input_tensors) {
+  // Read CPU input into tensors.
+  RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
+  for (int i = 0; i < input_tensors.size(); ++i) {
+    const TfLiteType input_tensor_type =
+        interpreter_->tensor(interpreter_->inputs()[i])->type;
+    switch (input_tensor_type) {
+      case TfLiteType::kTfLiteFloat16:
+      case TfLiteType::kTfLiteFloat32: {
+        CopyTensorBufferToInterpreter<float>(input_tensors[i],
+                                             interpreter_.get(), i);
+        break;
+      }
+      case TfLiteType::kTfLiteUInt8: {
+        CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
+                                             interpreter_.get(), i);
+        break;
+      }
+      case TfLiteType::kTfLiteInt8: {
+        CopyTensorBufferToInterpreter<int8>(input_tensors[i],
+                                            interpreter_.get(), i);
+        break;
+      }
+      case TfLiteType::kTfLiteInt32: {
+        CopyTensorBufferToInterpreter<int32_t>(input_tensors[i],
+                                               interpreter_.get(), i);
+        break;
+      }
+      default:
+        return absl::InvalidArgumentError(
+            absl::StrCat("Unsupported input tensor type:", input_tensor_type));
+    }
+  }
+
+  // Run inference.
+  RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
+
+  // Output result tensors (CPU).
+  const auto& tensor_indexes = interpreter_->outputs();
+  std::vector<Tensor> output_tensors;
+  output_tensors.reserve(tensor_indexes.size());
+  for (int i = 0; i < tensor_indexes.size(); ++i) {
+    TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
+    Tensor::Shape shape{std::vector<int>{
+        tensor->dims->data, tensor->dims->data + tensor->dims->size}};
+    switch (tensor->type) {
+      case TfLiteType::kTfLiteFloat16:
+      case TfLiteType::kTfLiteFloat32:
+        output_tensors.emplace_back(Tensor::ElementType::kFloat32, shape);
+        CopyTensorBufferFromInterpreter<float>(interpreter_.get(), i,
+                                               &output_tensors.back());
+        break;
+      case TfLiteType::kTfLiteUInt8:
+        output_tensors.emplace_back(
+            Tensor::ElementType::kUInt8, shape,
+            Tensor::QuantizationParameters{tensor->params.scale,
+                                           tensor->params.zero_point});
+        CopyTensorBufferFromInterpreter<uint8>(interpreter_.get(), i,
+                                               &output_tensors.back());
+        break;
+      case TfLiteType::kTfLiteInt8:
+        output_tensors.emplace_back(
+            Tensor::ElementType::kInt8, shape,
+            Tensor::QuantizationParameters{tensor->params.scale,
+                                           tensor->params.zero_point});
+        CopyTensorBufferFromInterpreter<int8>(interpreter_.get(), i,
+                                              &output_tensors.back());
+        break;
+      case TfLiteType::kTfLiteInt32:
+        output_tensors.emplace_back(Tensor::ElementType::kInt32, shape);
+        CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
+                                                 &output_tensors.back());
+        break;
+      default:
+        return absl::InvalidArgumentError(
+            absl::StrCat("Unsupported output tensor type:",
+                         TfLiteTypeGetName(tensor->type)));
+    }
+  }
+  return output_tensors;
+}
+
+absl::StatusOr<std::unique_ptr<InferenceRunner>>
+CreateInferenceInterpreterDelegateRunner(
+    api2::Packet<TfLiteModelPtr> model,
+    api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate,
+    int interpreter_num_threads) {
+  tflite::InterpreterBuilder interpreter_builder(*model.Get(),
+                                                 op_resolver.Get());
+  if (delegate) {
+    interpreter_builder.AddDelegate(delegate.get());
+  }
+#if defined(__EMSCRIPTEN__)
+  interpreter_builder.SetNumThreads(1);
+#else
+  interpreter_builder.SetNumThreads(interpreter_num_threads);
+#endif  // __EMSCRIPTEN__
+  std::unique_ptr<tflite::Interpreter> interpreter;
+  RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk);
+  RET_CHECK(interpreter);
+  RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk);
+  return std::make_unique<InferenceInterpreterDelegateRunner>(
+      std::move(model), std::move(interpreter), std::move(delegate));
+}
+
+}  // namespace mediapipe
mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h (new file)
@@ -0,0 +1,46 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_INTERPRETER_DELEGATE_RUNNER_H_
+#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_INTERPRETER_DELEGATE_RUNNER_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "mediapipe/calculators/tensor/inference_runner.h"
+#include "mediapipe/framework/api2/packet.h"
+#include "mediapipe/util/tflite/tflite_model_loader.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/interpreter.h"
+
+namespace mediapipe {
+
+using TfLiteDelegatePtr =
+    std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
+
+// Creates inference runner which run inference using newly initialized
+// interpreter and provided `delegate`.
+//
+// `delegate` can be nullptr, in that case newly initialized interpreter will
+// use what is available by default.
+absl::StatusOr<std::unique_ptr<InferenceRunner>>
+CreateInferenceInterpreterDelegateRunner(
+    api2::Packet<TfLiteModelPtr> model,
+    api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate,
+    int interpreter_num_threads);
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_INTERPRETER_DELEGATE_RUNNER_H_
mediapipe/calculators/tensor/inference_runner.h (new file, 19 lines)
@@ -0,0 +1,19 @@
+#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_
+#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_
+
+#include "absl/status/statusor.h"
+#include "mediapipe/framework/formats/tensor.h"
+
+namespace mediapipe {
+
+// Common interface to implement inference runners in MediaPipe.
+class InferenceRunner {
+ public:
+  virtual ~InferenceRunner() = default;
+  virtual absl::StatusOr<std::vector<Tensor>> Run(
+      const std::vector<Tensor>& inputs) = 0;
+};
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_
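
Because InferenceRunner is a plain virtual interface, alternative runners can be swapped into the calculators without touching their Process() logic. Below is a minimal sketch of a stub implementation, e.g. for unit tests; the class name and behavior are hypothetical and not part of this change.

// Hypothetical stub runner: ignores the input values and returns a fixed
// number of float32 scalar tensors, which is enough to exercise calculator
// plumbing in tests.
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/inference_runner.h"
#include "mediapipe/framework/formats/tensor.h"

namespace mediapipe {

class FakeInferenceRunner : public InferenceRunner {
 public:
  explicit FakeInferenceRunner(int num_outputs) : num_outputs_(num_outputs) {}

  absl::StatusOr<std::vector<Tensor>> Run(
      const std::vector<Tensor>& inputs) override {
    if (inputs.empty()) {
      return absl::InvalidArgumentError("Expected at least one input tensor.");
    }
    std::vector<Tensor> outputs;
    outputs.reserve(num_outputs_);
    for (int i = 0; i < num_outputs_; ++i) {
      // One float32 scalar per requested output.
      outputs.emplace_back(Tensor::ElementType::kFloat32,
                           Tensor::Shape{std::vector<int>{1}});
    }
    return outputs;
  }

 private:
  const int num_outputs_;
};

}  // namespace mediapipe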
@@ -69,7 +69,8 @@ objc_library(
         "-Wno-shorten-64-to-32",
     ],
     sdk_frameworks = ["Accelerate"],
-    visibility = ["//mediapipe/framework:mediapipe_internal"],
+    # This build rule is public to allow external customers to build their own iOS apps.
+    visibility = ["//visibility:public"],
     deps = [
         ":CFHolder",
         ":util",

@@ -124,7 +125,8 @@ objc_library(
         "CoreVideo",
         "Foundation",
     ],
-    visibility = ["//mediapipe/framework:mediapipe_internal"],
+    # This build rule is public to allow external customers to build their own iOS apps.
+    visibility = ["//visibility:public"],
 )
 
 objc_library(

@@ -166,7 +168,8 @@ objc_library(
         "Foundation",
         "GLKit",
     ],
-    visibility = ["//mediapipe/framework:mediapipe_internal"],
+    # This build rule is public to allow external customers to build their own iOS apps.
+    visibility = ["//visibility:public"],
     deps = [
         ":mediapipe_framework_ios",
         ":mediapipe_gl_view_renderer",
@@ -17,21 +17,21 @@ import os
 import shutil
 import urllib.request
 
-_OSS_URL_PREFIX = 'https://github.com/google/mediapipe/raw/master/'
+_GCS_URL_PREFIX = 'https://storage.googleapis.com/mediapipe-assets/'
 
 
 def download_oss_model(model_path: str):
-  """Downloads the oss model from the MediaPipe GitHub repo if it doesn't exist in the package."""
+  """Downloads the oss model from Google Cloud Storage if it doesn't exist in the package."""
 
   mp_root_path = os.sep.join(os.path.abspath(__file__).split(os.sep)[:-4])
   model_abspath = os.path.join(mp_root_path, model_path)
   if os.path.exists(model_abspath):
     return
-  model_url = _OSS_URL_PREFIX + model_path
+  model_url = _GCS_URL_PREFIX + model_path.split('/')[-1]
   print('Downloading model to ' + model_abspath)
   with urllib.request.urlopen(model_url) as response, open(model_abspath,
                                                            'wb') as out_file:
     if response.code != 200:
       raise ConnectionError('Cannot download ' + model_path +
-                            ' from the MediaPipe Github repo.')
+                            ' from Google Cloud Storage.')
     shutil.copyfileobj(response, out_file)
@@ -142,7 +142,7 @@ TEST_F(CreateTest, FailsWithSelectiveOpResolverMissingOps) {
   // interpreter errors (e.g., "Encountered unresolved custom op").
   EXPECT_EQ(image_classifier_or.status().code(), absl::StatusCode::kInternal);
   EXPECT_THAT(image_classifier_or.status().message(),
-              HasSubstr("interpreter_builder(&interpreter_) == kTfLiteOk"));
+              HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
 }
 TEST_F(CreateTest, FailsWithMissingModel) {
   auto image_classifier_or =
@@ -194,7 +194,7 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
   // interpreter errors (e.g., "Encountered unresolved custom op").
   EXPECT_EQ(object_detector.status().code(), absl::StatusCode::kInternal);
   EXPECT_THAT(object_detector.status().message(),
-              HasSubstr("interpreter_->AllocateTensors() == kTfLiteOk"));
+              HasSubstr("interpreter->AllocateTensors() == kTfLiteOk"));
 }
 
 TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
@@ -185,7 +185,7 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
   EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
   EXPECT_THAT(
       segmenter_or.status().message(),
-      testing::HasSubstr("interpreter_builder(&interpreter_) == kTfLiteOk"));
+      testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
 }
 
 TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {