Project import generated by Copybara.
GitOrigin-RevId: 3ce19771d2586aeb611fff75bb7627721cf5d36b
This commit is contained in:
parent
4dc4b19ddb
commit
d3f98334bf
|
@ -51,6 +51,8 @@ RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100 --slave /u
|
|||
RUN pip3 install --upgrade setuptools
|
||||
RUN pip3 install wheel
|
||||
RUN pip3 install future
|
||||
RUN pip3 install absl-py
|
||||
RUN pip3 install numpy
|
||||
RUN pip3 install six==1.14.0
|
||||
RUN pip3 install tensorflow==2.2.0
|
||||
RUN pip3 install tf_slim
|
||||
|
|
|
@ -216,6 +216,50 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "inference_runner",
|
||||
hdrs = ["inference_runner.h"],
|
||||
copts = select({
|
||||
# TODO: fix tensor.h not to require this, if possible
|
||||
"//mediapipe:apple": [
|
||||
"-x objective-c++",
|
||||
"-fobjc-arc", # enable reference-counting
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "inference_interpreter_delegate_runner",
|
||||
srcs = ["inference_interpreter_delegate_runner.cc"],
|
||||
hdrs = ["inference_interpreter_delegate_runner.h"],
|
||||
copts = select({
|
||||
# TODO: fix tensor.h not to require this, if possible
|
||||
"//mediapipe:apple": [
|
||||
"-x objective-c++",
|
||||
"-fobjc-arc", # enable reference-counting
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_runner",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/util/tflite:tflite_model_loader",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "inference_calculator_cpu",
|
||||
srcs = [
|
||||
|
@ -232,22 +276,63 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_calculator_interface",
|
||||
":inference_calculator_utils",
|
||||
":inference_interpreter_delegate_runner",
|
||||
":inference_runner",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"//mediapipe/util:cpu_util",
|
||||
],
|
||||
}) + select({
|
||||
"//conditions:default": [],
|
||||
"//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "inference_calculator_utils",
|
||||
srcs = ["inference_calculator_utils.cc"],
|
||||
hdrs = ["inference_calculator_utils.h"],
|
||||
deps = [
|
||||
":inference_calculator_cc_proto",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"//mediapipe/util:cpu_util",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "inference_calculator_xnnpack",
|
||||
srcs = [
|
||||
"inference_calculator_xnnpack.cc",
|
||||
],
|
||||
copts = select({
|
||||
# TODO: fix tensor.h not to require this, if possible
|
||||
"//mediapipe:apple": [
|
||||
"-x objective-c++",
|
||||
"-fobjc-arc", # enable reference-counting
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":inference_calculator_interface",
|
||||
":inference_calculator_utils",
|
||||
":inference_interpreter_delegate_runner",
|
||||
":inference_runner",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "inference_calculator_gl_if_compute_shader_available",
|
||||
visibility = ["//visibility:public"],
|
||||
|
|
|
@ -59,6 +59,7 @@ class InferenceCalculatorSelectorImpl
|
|||
}
|
||||
}
|
||||
impls.emplace_back("Cpu");
|
||||
impls.emplace_back("Xnnpack");
|
||||
for (const auto& suffix : impls) {
|
||||
const auto impl = absl::StrCat("InferenceCalculator", suffix);
|
||||
if (!mediapipe::CalculatorBaseRegistry::IsRegistered(impl)) continue;
|
||||
|
|
|
@ -141,6 +141,10 @@ struct InferenceCalculatorCpu : public InferenceCalculator {
|
|||
static constexpr char kCalculatorName[] = "InferenceCalculatorCpu";
|
||||
};
|
||||
|
||||
struct InferenceCalculatorXnnpack : public InferenceCalculator {
|
||||
static constexpr char kCalculatorName[] = "InferenceCalculatorXnnpack";
|
||||
};
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
|
|
|
@ -18,78 +18,21 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
||||
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
||||
#include "mediapipe/calculators/tensor/inference_runner.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
#if defined(MEDIAPIPE_ANDROID)
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#endif // ANDROID
|
||||
|
||||
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
|
||||
#include "mediapipe/util/cpu_util.h"
|
||||
#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__
|
||||
|
||||
#include "tensorflow/lite/c/c_api_types.h"
|
||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
namespace {
|
||||
|
||||
int GetXnnpackDefaultNumThreads() {
|
||||
#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_IOS) || \
|
||||
defined(__EMSCRIPTEN_PTHREADS__)
|
||||
constexpr int kMinNumThreadsByDefault = 1;
|
||||
constexpr int kMaxNumThreadsByDefault = 4;
|
||||
return std::clamp(NumCPUCores() / 2, kMinNumThreadsByDefault,
|
||||
kMaxNumThreadsByDefault);
|
||||
#else
|
||||
return 1;
|
||||
#endif // MEDIAPIPE_ANDROID || MEDIAPIPE_IOS || __EMSCRIPTEN_PTHREADS__
|
||||
}
|
||||
|
||||
// Returns number of threads to configure XNNPACK delegate with.
|
||||
// Returns user provided value if specified. Otherwise, tries to choose optimal
|
||||
// number of threads depending on the device.
|
||||
int GetXnnpackNumThreads(
|
||||
const bool opts_has_delegate,
|
||||
const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) {
|
||||
static constexpr int kDefaultNumThreads = -1;
|
||||
if (opts_has_delegate && opts_delegate.has_xnnpack() &&
|
||||
opts_delegate.xnnpack().num_threads() != kDefaultNumThreads) {
|
||||
return opts_delegate.xnnpack().num_threads();
|
||||
}
|
||||
return GetXnnpackDefaultNumThreads();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
||||
tflite::Interpreter* interpreter,
|
||||
int input_tensor_index) {
|
||||
auto input_tensor_view = input_tensor.GetCpuReadView();
|
||||
auto input_tensor_buffer = input_tensor_view.buffer<T>();
|
||||
T* local_tensor_buffer =
|
||||
interpreter->typed_input_tensor<T>(input_tensor_index);
|
||||
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
|
||||
int output_tensor_index,
|
||||
Tensor* output_tensor) {
|
||||
auto output_tensor_view = output_tensor->GetCpuWriteView();
|
||||
auto output_tensor_buffer = output_tensor_view.buffer<T>();
|
||||
T* local_tensor_buffer =
|
||||
interpreter->typed_output_tensor<T>(output_tensor_index);
|
||||
std::memcpy(output_tensor_buffer, local_tensor_buffer,
|
||||
output_tensor->bytes());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class InferenceCalculatorCpuImpl
|
||||
: public NodeImpl<InferenceCalculatorCpu, InferenceCalculatorCpuImpl> {
|
||||
public:
|
||||
|
@ -100,16 +43,11 @@ class InferenceCalculatorCpuImpl
|
|||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
absl::Status InitInterpreter(CalculatorContext* cc);
|
||||
absl::Status LoadDelegate(CalculatorContext* cc,
|
||||
tflite::InterpreterBuilder* interpreter_builder);
|
||||
absl::Status AllocateTensors();
|
||||
absl::StatusOr<std::unique_ptr<InferenceRunner>> CreateInferenceRunner(
|
||||
CalculatorContext* cc);
|
||||
absl::StatusOr<TfLiteDelegatePtr> MaybeCreateDelegate(CalculatorContext* cc);
|
||||
|
||||
// TfLite requires us to keep the model alive as long as the interpreter is.
|
||||
Packet<TfLiteModelPtr> model_packet_;
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||
TfLiteDelegatePtr delegate_;
|
||||
TfLiteType input_tensor_type_ = TfLiteType::kTfLiteNoType;
|
||||
std::unique_ptr<InferenceRunner> inference_runner_;
|
||||
};
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::UpdateContract(
|
||||
|
@ -122,7 +60,8 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract(
|
|||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) {
|
||||
return InitInterpreter(cc);
|
||||
ASSIGN_OR_RETURN(inference_runner_, CreateInferenceRunner(cc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
|
||||
|
@ -131,123 +70,32 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
|
|||
}
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
RET_CHECK(!input_tensors.empty());
|
||||
auto output_tensors = absl::make_unique<std::vector<Tensor>>();
|
||||
|
||||
if (input_tensor_type_ == kTfLiteNoType) {
|
||||
input_tensor_type_ = interpreter_->tensor(interpreter_->inputs()[0])->type;
|
||||
}
|
||||
|
||||
// Read CPU input into tensors.
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
switch (input_tensor_type_) {
|
||||
case TfLiteType::kTfLiteFloat16:
|
||||
case TfLiteType::kTfLiteFloat32: {
|
||||
CopyTensorBufferToInterpreter<float>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteUInt8: {
|
||||
CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteInt8: {
|
||||
CopyTensorBufferToInterpreter<int8>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteInt32: {
|
||||
CopyTensorBufferToInterpreter<int32_t>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported input tensor type:", input_tensor_type_));
|
||||
}
|
||||
}
|
||||
|
||||
// Run inference.
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
|
||||
// Output result tensors (CPU).
|
||||
const auto& tensor_indexes = interpreter_->outputs();
|
||||
output_tensors->reserve(tensor_indexes.size());
|
||||
for (int i = 0; i < tensor_indexes.size(); ++i) {
|
||||
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
|
||||
Tensor::Shape shape{std::vector<int>{
|
||||
tensor->dims->data, tensor->dims->data + tensor->dims->size}};
|
||||
switch (tensor->type) {
|
||||
case TfLiteType::kTfLiteFloat16:
|
||||
case TfLiteType::kTfLiteFloat32:
|
||||
output_tensors->emplace_back(Tensor::ElementType::kFloat32, shape);
|
||||
CopyTensorBufferFromInterpreter<float>(interpreter_.get(), i,
|
||||
&output_tensors->back());
|
||||
break;
|
||||
case TfLiteType::kTfLiteUInt8:
|
||||
output_tensors->emplace_back(
|
||||
Tensor::ElementType::kUInt8, shape,
|
||||
Tensor::QuantizationParameters{tensor->params.scale,
|
||||
tensor->params.zero_point});
|
||||
CopyTensorBufferFromInterpreter<uint8>(interpreter_.get(), i,
|
||||
&output_tensors->back());
|
||||
break;
|
||||
case TfLiteType::kTfLiteInt8:
|
||||
output_tensors->emplace_back(
|
||||
Tensor::ElementType::kInt8, shape,
|
||||
Tensor::QuantizationParameters{tensor->params.scale,
|
||||
tensor->params.zero_point});
|
||||
CopyTensorBufferFromInterpreter<int8>(interpreter_.get(), i,
|
||||
&output_tensors->back());
|
||||
break;
|
||||
case TfLiteType::kTfLiteInt32:
|
||||
output_tensors->emplace_back(Tensor::ElementType::kInt32, shape);
|
||||
CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
|
||||
&output_tensors->back());
|
||||
break;
|
||||
default:
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported output tensor type:",
|
||||
TfLiteTypeGetName(tensor->type)));
|
||||
}
|
||||
}
|
||||
ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
|
||||
inference_runner_->Run(input_tensors));
|
||||
kOutTensors(cc).Send(std::move(output_tensors));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::Close(CalculatorContext* cc) {
|
||||
interpreter_ = nullptr;
|
||||
delegate_ = nullptr;
|
||||
inference_runner_ = nullptr;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::InitInterpreter(
|
||||
CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
||||
const auto& model = *model_packet_.Get();
|
||||
absl::StatusOr<std::unique_ptr<InferenceRunner>>
|
||||
InferenceCalculatorCpuImpl::CreateInferenceRunner(CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(auto model_packet, GetModelAsPacket(cc));
|
||||
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
|
||||
const auto& op_resolver = op_resolver_packet.Get();
|
||||
tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
|
||||
MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
|
||||
#if defined(__EMSCRIPTEN__)
|
||||
interpreter_builder.SetNumThreads(1);
|
||||
#else
|
||||
interpreter_builder.SetNumThreads(
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
|
||||
#endif // __EMSCRIPTEN__
|
||||
|
||||
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
|
||||
RET_CHECK(interpreter_);
|
||||
return AllocateTensors();
|
||||
const int interpreter_num_threads =
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread();
|
||||
ASSIGN_OR_RETURN(TfLiteDelegatePtr delegate, MaybeCreateDelegate(cc));
|
||||
return CreateInferenceInterpreterDelegateRunner(
|
||||
std::move(model_packet), std::move(op_resolver_packet),
|
||||
std::move(delegate), interpreter_num_threads);
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::AllocateTensors() {
|
||||
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
|
||||
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
|
||||
absl::StatusOr<TfLiteDelegatePtr>
|
||||
InferenceCalculatorCpuImpl::MaybeCreateDelegate(CalculatorContext* cc) {
|
||||
const auto& calculator_opts =
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>();
|
||||
auto opts_delegate = calculator_opts.delegate();
|
||||
|
@ -268,7 +116,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
|
|||
calculator_opts.has_delegate() || !kDelegate(cc).IsEmpty();
|
||||
if (opts_has_delegate && opts_delegate.has_tflite()) {
|
||||
// Default tflite inference requeqsted - no need to modify graph.
|
||||
return absl::OkStatus();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#if defined(MEDIAPIPE_ANDROID)
|
||||
|
@ -288,10 +136,8 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
|
|||
options.accelerator_name = nnapi.has_accelerator_name()
|
||||
? nnapi.accelerator_name().c_str()
|
||||
: nullptr;
|
||||
delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
|
||||
[](TfLiteDelegate*) {});
|
||||
interpreter_builder->AddDelegate(delegate_.get());
|
||||
return absl::OkStatus();
|
||||
return TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
|
||||
[](TfLiteDelegate*) {});
|
||||
}
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
|
||||
|
@ -305,12 +151,11 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
|
|||
auto xnnpack_opts = TfLiteXNNPackDelegateOptionsDefault();
|
||||
xnnpack_opts.num_threads =
|
||||
GetXnnpackNumThreads(opts_has_delegate, opts_delegate);
|
||||
delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
|
||||
&TfLiteXNNPackDelegateDelete);
|
||||
interpreter_builder->AddDelegate(delegate_.get());
|
||||
return TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
|
||||
&TfLiteXNNPackDelegateDelete);
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
|
|
52
mediapipe/calculators/tensor/inference_calculator_utils.cc
Normal file
52
mediapipe/calculators/tensor/inference_calculator_utils.cc
Normal file
|
@ -0,0 +1,52 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
||||
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||
|
||||
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
|
||||
#include "mediapipe/util/cpu_util.h"
|
||||
#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
int GetXnnpackDefaultNumThreads() {
|
||||
#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_IOS) || \
|
||||
defined(__EMSCRIPTEN_PTHREADS__)
|
||||
constexpr int kMinNumThreadsByDefault = 1;
|
||||
constexpr int kMaxNumThreadsByDefault = 4;
|
||||
return std::clamp(NumCPUCores() / 2, kMinNumThreadsByDefault,
|
||||
kMaxNumThreadsByDefault);
|
||||
#else
|
||||
return 1;
|
||||
#endif // MEDIAPIPE_ANDROID || MEDIAPIPE_IOS || __EMSCRIPTEN_PTHREADS__
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int GetXnnpackNumThreads(
|
||||
const bool opts_has_delegate,
|
||||
const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) {
|
||||
static constexpr int kDefaultNumThreads = -1;
|
||||
if (opts_has_delegate && opts_delegate.has_xnnpack() &&
|
||||
opts_delegate.xnnpack().num_threads() != kDefaultNumThreads) {
|
||||
return opts_delegate.xnnpack().num_threads();
|
||||
}
|
||||
return GetXnnpackDefaultNumThreads();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
31
mediapipe/calculators/tensor/inference_calculator_utils.h
Normal file
31
mediapipe/calculators/tensor/inference_calculator_utils.h
Normal file
|
@ -0,0 +1,31 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_
|
||||
#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_
|
||||
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Returns number of threads to configure XNNPACK delegate with.
|
||||
// Returns user provided value if specified. Otherwise, tries to choose optimal
|
||||
// number of threads depending on the device.
|
||||
int GetXnnpackNumThreads(
|
||||
const bool opts_has_delegate,
|
||||
const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_
|
122
mediapipe/calculators/tensor/inference_calculator_xnnpack.cc
Normal file
122
mediapipe/calculators/tensor/inference_calculator_xnnpack.cc
Normal file
|
@ -0,0 +1,122 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
||||
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
||||
#include "mediapipe/calculators/tensor/inference_runner.h"
|
||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
class InferenceCalculatorXnnpackImpl
|
||||
: public NodeImpl<InferenceCalculatorXnnpack,
|
||||
InferenceCalculatorXnnpackImpl> {
|
||||
public:
|
||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
absl::StatusOr<std::unique_ptr<InferenceRunner>> CreateInferenceRunner(
|
||||
CalculatorContext* cc);
|
||||
absl::StatusOr<TfLiteDelegatePtr> CreateDelegate(CalculatorContext* cc);
|
||||
|
||||
std::unique_ptr<InferenceRunner> inference_runner_;
|
||||
};
|
||||
|
||||
absl::Status InferenceCalculatorXnnpackImpl::UpdateContract(
|
||||
CalculatorContract* cc) {
|
||||
const auto& options = cc->Options<mediapipe::InferenceCalculatorOptions>();
|
||||
RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
|
||||
<< "Either model as side packet or model path in options is required.";
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorXnnpackImpl::Open(CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(inference_runner_, CreateInferenceRunner(cc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorXnnpackImpl::Process(CalculatorContext* cc) {
|
||||
if (kInTensors(cc).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
RET_CHECK(!input_tensors.empty());
|
||||
|
||||
ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors,
|
||||
inference_runner_->Run(input_tensors));
|
||||
kOutTensors(cc).Send(std::move(output_tensors));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorXnnpackImpl::Close(CalculatorContext* cc) {
|
||||
inference_runner_ = nullptr;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<InferenceRunner>>
|
||||
InferenceCalculatorXnnpackImpl::CreateInferenceRunner(CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(auto model_packet, GetModelAsPacket(cc));
|
||||
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
|
||||
const int interpreter_num_threads =
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread();
|
||||
ASSIGN_OR_RETURN(TfLiteDelegatePtr delegate, CreateDelegate(cc));
|
||||
return CreateInferenceInterpreterDelegateRunner(
|
||||
std::move(model_packet), std::move(op_resolver_packet),
|
||||
std::move(delegate), interpreter_num_threads);
|
||||
}
|
||||
|
||||
absl::StatusOr<TfLiteDelegatePtr>
|
||||
InferenceCalculatorXnnpackImpl::CreateDelegate(CalculatorContext* cc) {
|
||||
const auto& calculator_opts =
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>();
|
||||
auto opts_delegate = calculator_opts.delegate();
|
||||
if (!kDelegate(cc).IsEmpty()) {
|
||||
const mediapipe::InferenceCalculatorOptions::Delegate&
|
||||
input_side_packet_delegate = kDelegate(cc).Get();
|
||||
RET_CHECK(
|
||||
input_side_packet_delegate.has_xnnpack() ||
|
||||
input_side_packet_delegate.delegate_case() ==
|
||||
mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET)
|
||||
<< "inference_calculator_cpu only supports delegate input side packet "
|
||||
<< "for TFLite, XNNPack";
|
||||
opts_delegate.MergeFrom(input_side_packet_delegate);
|
||||
}
|
||||
const bool opts_has_delegate =
|
||||
calculator_opts.has_delegate() || !kDelegate(cc).IsEmpty();
|
||||
|
||||
auto xnnpack_opts = TfLiteXNNPackDelegateOptionsDefault();
|
||||
xnnpack_opts.num_threads =
|
||||
GetXnnpackNumThreads(opts_has_delegate, opts_delegate);
|
||||
return TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
|
||||
&TfLiteXNNPackDelegateDelete);
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,181 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
||||
tflite::Interpreter* interpreter,
|
||||
int input_tensor_index) {
|
||||
auto input_tensor_view = input_tensor.GetCpuReadView();
|
||||
auto input_tensor_buffer = input_tensor_view.buffer<T>();
|
||||
T* local_tensor_buffer =
|
||||
interpreter->typed_input_tensor<T>(input_tensor_index);
|
||||
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
|
||||
int output_tensor_index,
|
||||
Tensor* output_tensor) {
|
||||
auto output_tensor_view = output_tensor->GetCpuWriteView();
|
||||
auto output_tensor_buffer = output_tensor_view.buffer<T>();
|
||||
T* local_tensor_buffer =
|
||||
interpreter->typed_output_tensor<T>(output_tensor_index);
|
||||
std::memcpy(output_tensor_buffer, local_tensor_buffer,
|
||||
output_tensor->bytes());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class InferenceInterpreterDelegateRunner : public InferenceRunner {
|
||||
public:
|
||||
InferenceInterpreterDelegateRunner(
|
||||
api2::Packet<TfLiteModelPtr> model,
|
||||
std::unique_ptr<tflite::Interpreter> interpreter,
|
||||
TfLiteDelegatePtr delegate)
|
||||
: model_(std::move(model)),
|
||||
interpreter_(std::move(interpreter)),
|
||||
delegate_(std::move(delegate)) {}
|
||||
|
||||
absl::StatusOr<std::vector<Tensor>> Run(
|
||||
const std::vector<Tensor>& input_tensors) override;
|
||||
|
||||
private:
|
||||
api2::Packet<TfLiteModelPtr> model_;
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||
TfLiteDelegatePtr delegate_;
|
||||
};
|
||||
|
||||
absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
||||
const std::vector<Tensor>& input_tensors) {
|
||||
// Read CPU input into tensors.
|
||||
RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
const TfLiteType input_tensor_type =
|
||||
interpreter_->tensor(interpreter_->inputs()[i])->type;
|
||||
switch (input_tensor_type) {
|
||||
case TfLiteType::kTfLiteFloat16:
|
||||
case TfLiteType::kTfLiteFloat32: {
|
||||
CopyTensorBufferToInterpreter<float>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteUInt8: {
|
||||
CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteInt8: {
|
||||
CopyTensorBufferToInterpreter<int8>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteInt32: {
|
||||
CopyTensorBufferToInterpreter<int32_t>(input_tensors[i],
|
||||
interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported input tensor type:", input_tensor_type));
|
||||
}
|
||||
}
|
||||
|
||||
// Run inference.
|
||||
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
|
||||
// Output result tensors (CPU).
|
||||
const auto& tensor_indexes = interpreter_->outputs();
|
||||
std::vector<Tensor> output_tensors;
|
||||
output_tensors.reserve(tensor_indexes.size());
|
||||
for (int i = 0; i < tensor_indexes.size(); ++i) {
|
||||
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
|
||||
Tensor::Shape shape{std::vector<int>{
|
||||
tensor->dims->data, tensor->dims->data + tensor->dims->size}};
|
||||
switch (tensor->type) {
|
||||
case TfLiteType::kTfLiteFloat16:
|
||||
case TfLiteType::kTfLiteFloat32:
|
||||
output_tensors.emplace_back(Tensor::ElementType::kFloat32, shape);
|
||||
CopyTensorBufferFromInterpreter<float>(interpreter_.get(), i,
|
||||
&output_tensors.back());
|
||||
break;
|
||||
case TfLiteType::kTfLiteUInt8:
|
||||
output_tensors.emplace_back(
|
||||
Tensor::ElementType::kUInt8, shape,
|
||||
Tensor::QuantizationParameters{tensor->params.scale,
|
||||
tensor->params.zero_point});
|
||||
CopyTensorBufferFromInterpreter<uint8>(interpreter_.get(), i,
|
||||
&output_tensors.back());
|
||||
break;
|
||||
case TfLiteType::kTfLiteInt8:
|
||||
output_tensors.emplace_back(
|
||||
Tensor::ElementType::kInt8, shape,
|
||||
Tensor::QuantizationParameters{tensor->params.scale,
|
||||
tensor->params.zero_point});
|
||||
CopyTensorBufferFromInterpreter<int8>(interpreter_.get(), i,
|
||||
&output_tensors.back());
|
||||
break;
|
||||
case TfLiteType::kTfLiteInt32:
|
||||
output_tensors.emplace_back(Tensor::ElementType::kInt32, shape);
|
||||
CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
|
||||
&output_tensors.back());
|
||||
break;
|
||||
default:
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported output tensor type:",
|
||||
TfLiteTypeGetName(tensor->type)));
|
||||
}
|
||||
}
|
||||
return output_tensors;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<InferenceRunner>>
|
||||
CreateInferenceInterpreterDelegateRunner(
|
||||
api2::Packet<TfLiteModelPtr> model,
|
||||
api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate,
|
||||
int interpreter_num_threads) {
|
||||
tflite::InterpreterBuilder interpreter_builder(*model.Get(),
|
||||
op_resolver.Get());
|
||||
if (delegate) {
|
||||
interpreter_builder.AddDelegate(delegate.get());
|
||||
}
|
||||
#if defined(__EMSCRIPTEN__)
|
||||
interpreter_builder.SetNumThreads(1);
|
||||
#else
|
||||
interpreter_builder.SetNumThreads(interpreter_num_threads);
|
||||
#endif // __EMSCRIPTEN__
|
||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
||||
RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk);
|
||||
RET_CHECK(interpreter);
|
||||
RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
return std::make_unique<InferenceInterpreterDelegateRunner>(
|
||||
std::move(model), std::move(interpreter), std::move(delegate));
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_INTERPRETER_DELEGATE_RUNNER_H_
|
||||
#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_INTERPRETER_DELEGATE_RUNNER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/tensor/inference_runner.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
using TfLiteDelegatePtr =
|
||||
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
|
||||
|
||||
// Creates inference runner which run inference using newly initialized
|
||||
// interpreter and provided `delegate`.
|
||||
//
|
||||
// `delegate` can be nullptr, in that case newly initialized interpreter will
|
||||
// use what is available by default.
|
||||
absl::StatusOr<std::unique_ptr<InferenceRunner>>
|
||||
CreateInferenceInterpreterDelegateRunner(
|
||||
api2::Packet<TfLiteModelPtr> model,
|
||||
api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate,
|
||||
int interpreter_num_threads);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_INTERPRETER_DELEGATE_RUNNER_H_
|
19
mediapipe/calculators/tensor/inference_runner.h
Normal file
19
mediapipe/calculators/tensor/inference_runner.h
Normal file
|
@ -0,0 +1,19 @@
|
|||
#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_
|
||||
#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// Common interface to implement inference runners in MediaPipe.
|
||||
class InferenceRunner {
|
||||
public:
|
||||
virtual ~InferenceRunner() = default;
|
||||
virtual absl::StatusOr<std::vector<Tensor>> Run(
|
||||
const std::vector<Tensor>& inputs) = 0;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_
|
|
@ -69,7 +69,8 @@ objc_library(
|
|||
"-Wno-shorten-64-to-32",
|
||||
],
|
||||
sdk_frameworks = ["Accelerate"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
# This build rule is public to allow external customers to build their own iOS apps.
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":CFHolder",
|
||||
":util",
|
||||
|
@ -124,7 +125,8 @@ objc_library(
|
|||
"CoreVideo",
|
||||
"Foundation",
|
||||
],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
# This build rule is public to allow external customers to build their own iOS apps.
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
|
@ -166,7 +168,8 @@ objc_library(
|
|||
"Foundation",
|
||||
"GLKit",
|
||||
],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
# This build rule is public to allow external customers to build their own iOS apps.
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":mediapipe_framework_ios",
|
||||
":mediapipe_gl_view_renderer",
|
||||
|
|
|
@ -17,21 +17,21 @@ import os
|
|||
import shutil
|
||||
import urllib.request
|
||||
|
||||
_OSS_URL_PREFIX = 'https://github.com/google/mediapipe/raw/master/'
|
||||
_GCS_URL_PREFIX = 'https://storage.googleapis.com/mediapipe-assets/'
|
||||
|
||||
|
||||
def download_oss_model(model_path: str):
|
||||
"""Downloads the oss model from the MediaPipe GitHub repo if it doesn't exist in the package."""
|
||||
"""Downloads the oss model from Google Cloud Storage if it doesn't exist in the package."""
|
||||
|
||||
mp_root_path = os.sep.join(os.path.abspath(__file__).split(os.sep)[:-4])
|
||||
model_abspath = os.path.join(mp_root_path, model_path)
|
||||
if os.path.exists(model_abspath):
|
||||
return
|
||||
model_url = _OSS_URL_PREFIX + model_path
|
||||
model_url = _GCS_URL_PREFIX + model_path.split('/')[-1]
|
||||
print('Downloading model to ' + model_abspath)
|
||||
with urllib.request.urlopen(model_url) as response, open(model_abspath,
|
||||
'wb') as out_file:
|
||||
if response.code != 200:
|
||||
raise ConnectionError('Cannot download ' + model_path +
|
||||
' from the MediaPipe Github repo.')
|
||||
' from Google Cloud Storage.')
|
||||
shutil.copyfileobj(response, out_file)
|
||||
|
|
|
@ -142,7 +142,7 @@ TEST_F(CreateTest, FailsWithSelectiveOpResolverMissingOps) {
|
|||
// interpreter errors (e.g., "Encountered unresolved custom op").
|
||||
EXPECT_EQ(image_classifier_or.status().code(), absl::StatusCode::kInternal);
|
||||
EXPECT_THAT(image_classifier_or.status().message(),
|
||||
HasSubstr("interpreter_builder(&interpreter_) == kTfLiteOk"));
|
||||
HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
|
||||
}
|
||||
TEST_F(CreateTest, FailsWithMissingModel) {
|
||||
auto image_classifier_or =
|
||||
|
|
|
@ -194,7 +194,7 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
|
|||
// interpreter errors (e.g., "Encountered unresolved custom op").
|
||||
EXPECT_EQ(object_detector.status().code(), absl::StatusCode::kInternal);
|
||||
EXPECT_THAT(object_detector.status().message(),
|
||||
HasSubstr("interpreter_->AllocateTensors() == kTfLiteOk"));
|
||||
HasSubstr("interpreter->AllocateTensors() == kTfLiteOk"));
|
||||
}
|
||||
|
||||
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||
|
|
|
@ -185,7 +185,7 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
|
|||
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
|
||||
EXPECT_THAT(
|
||||
segmenter_or.status().message(),
|
||||
testing::HasSubstr("interpreter_builder(&interpreter_) == kTfLiteOk"));
|
||||
testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
|
||||
}
|
||||
|
||||
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user