Depends on TFLite shim header.
PiperOrigin-RevId: 508491302
This commit is contained in:
parent
99fc975f49
commit
fd764dae0a
|
@ -21,6 +21,7 @@ load(
|
|||
)
|
||||
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
|
||||
load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto")
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
|
@ -370,10 +371,15 @@ mediapipe_proto_library(
|
|||
# size concerns), depend on those implementations directly, and do not depend on
|
||||
# :inference_calculator.
|
||||
# In all cases, use "InferenceCalulator" in your graphs.
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "inference_calculator_interface",
|
||||
srcs = ["inference_calculator.cc"],
|
||||
hdrs = ["inference_calculator.h"],
|
||||
tflite_deps = [
|
||||
"//mediapipe/util/tflite:tflite_model_loader",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
||||
],
|
||||
deps = [
|
||||
":inference_calculator_cc_proto",
|
||||
":inference_calculator_options_lib",
|
||||
|
@ -384,12 +390,9 @@ cc_library(
|
|||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
||||
"//mediapipe/framework/tool:subgraph_expansion",
|
||||
"//mediapipe/util/tflite:tflite_model_loader",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -473,22 +476,24 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "inference_interpreter_delegate_runner",
|
||||
srcs = ["inference_interpreter_delegate_runner.cc"],
|
||||
hdrs = ["inference_interpreter_delegate_runner.h"],
|
||||
tflite_deps = [
|
||||
"//mediapipe/util/tflite:tflite_model_loader",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||
],
|
||||
deps = [
|
||||
":inference_runner",
|
||||
"//mediapipe/framework:mediapipe_profiling",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/util/tflite:tflite_model_loader",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite:string_util",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
],
|
||||
)
|
||||
|
@ -506,9 +511,9 @@ cc_library(
|
|||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||
] + select({
|
||||
"//conditions:default": [],
|
||||
"//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"],
|
||||
|
|
|
@ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
|
|||
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
|
||||
}
|
||||
return PacketAdopting<tflite::OpResolver>(
|
||||
std::make_unique<
|
||||
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
|
||||
std::make_unique<tflite_shims::ops::builtin::
|
||||
BuiltinOpResolverWithoutDefaultDelegates>());
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
@ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf {
|
|||
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
|
||||
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
|
||||
// migration.
|
||||
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
|
||||
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
||||
static constexpr SideInput<tflite_shims::ops::builtin::BuiltinOpResolver>::
|
||||
Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
||||
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
|
||||
"OP_RESOLVER"};
|
||||
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
||||
|
@ -112,7 +112,8 @@ class InferenceCalculator : public NodeIntf {
|
|||
|
||||
protected:
|
||||
using TfLiteDelegatePtr =
|
||||
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
|
||||
std::unique_ptr<TfLiteOpaqueDelegate,
|
||||
std::function<void(TfLiteOpaqueDelegate*)>>;
|
||||
|
||||
static absl::StatusOr<Packet<TfLiteModelPtr>> GetModelAsPacket(
|
||||
CalculatorContext* cc);
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
||||
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
||||
#include "mediapipe/calculators/tensor/inference_runner.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
||||
#if defined(MEDIAPIPE_ANDROID)
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#endif // ANDROID
|
||||
|
|
|
@ -22,9 +22,9 @@
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/mediapipe_profiling.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "tensorflow/lite/c/c_api_types.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
||||
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
||||
#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
||||
|
@ -33,9 +33,12 @@ namespace mediapipe {
|
|||
|
||||
namespace {
|
||||
|
||||
using Interpreter = ::tflite_shims::Interpreter;
|
||||
using InterpreterBuilder = ::tflite_shims::InterpreterBuilder;
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
||||
tflite::Interpreter* interpreter,
|
||||
Interpreter* interpreter,
|
||||
int input_tensor_index) {
|
||||
auto input_tensor_view = input_tensor.GetCpuReadView();
|
||||
auto input_tensor_buffer = input_tensor_view.buffer<T>();
|
||||
|
@ -46,7 +49,7 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
|||
|
||||
template <>
|
||||
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
|
||||
tflite::Interpreter* interpreter,
|
||||
Interpreter* interpreter,
|
||||
int input_tensor_index) {
|
||||
const char* input_tensor_buffer =
|
||||
input_tensor.GetCpuReadView().buffer<char>();
|
||||
|
@ -58,7 +61,7 @@ void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
|
||||
void CopyTensorBufferFromInterpreter(Interpreter* interpreter,
|
||||
int output_tensor_index,
|
||||
Tensor* output_tensor) {
|
||||
auto output_tensor_view = output_tensor->GetCpuWriteView();
|
||||
|
@ -73,9 +76,8 @@ void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
|
|||
|
||||
class InferenceInterpreterDelegateRunner : public InferenceRunner {
|
||||
public:
|
||||
InferenceInterpreterDelegateRunner(
|
||||
api2::Packet<TfLiteModelPtr> model,
|
||||
std::unique_ptr<tflite::Interpreter> interpreter,
|
||||
InferenceInterpreterDelegateRunner(api2::Packet<TfLiteModelPtr> model,
|
||||
std::unique_ptr<Interpreter> interpreter,
|
||||
TfLiteDelegatePtr delegate)
|
||||
: model_(std::move(model)),
|
||||
interpreter_(std::move(interpreter)),
|
||||
|
@ -86,7 +88,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
|
|||
|
||||
private:
|
||||
api2::Packet<TfLiteModelPtr> model_;
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||
std::unique_ptr<Interpreter> interpreter_;
|
||||
TfLiteDelegatePtr delegate_;
|
||||
};
|
||||
|
||||
|
@ -197,8 +199,7 @@ CreateInferenceInterpreterDelegateRunner(
|
|||
api2::Packet<TfLiteModelPtr> model,
|
||||
api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate,
|
||||
int interpreter_num_threads) {
|
||||
tflite::InterpreterBuilder interpreter_builder(*model.Get(),
|
||||
op_resolver.Get());
|
||||
InterpreterBuilder interpreter_builder(*model.Get(), op_resolver.Get());
|
||||
if (delegate) {
|
||||
interpreter_builder.AddDelegate(delegate.get());
|
||||
}
|
||||
|
@ -207,7 +208,7 @@ CreateInferenceInterpreterDelegateRunner(
|
|||
#else
|
||||
interpreter_builder.SetNumThreads(interpreter_num_threads);
|
||||
#endif // __EMSCRIPTEN__
|
||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk);
|
||||
RET_CHECK(interpreter);
|
||||
RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
|
|
|
@ -23,12 +23,14 @@
|
|||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// TODO: Consider renaming TfLiteDelegatePtr.
|
||||
using TfLiteDelegatePtr =
|
||||
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
|
||||
std::unique_ptr<TfLiteOpaqueDelegate,
|
||||
std::function<void(TfLiteOpaqueDelegate*)>>;
|
||||
|
||||
// Creates inference runner which run inference using newly initialized
|
||||
// interpreter and provided `delegate`.
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = [
|
||||
|
@ -110,10 +112,13 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "tflite_model_loader",
|
||||
srcs = ["tflite_model_loader.cc"],
|
||||
hdrs = ["tflite_model_loader.h"],
|
||||
tflite_deps = [
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework/api2:packet",
|
||||
|
@ -121,6 +126,5 @@ cc_library(
|
|||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"//mediapipe/util:resource_util",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
using FlatBufferModel = ::tflite_shims::FlatBufferModel;
|
||||
|
||||
absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
|
||||
const std::string& path) {
|
||||
std::string model_path = path;
|
||||
|
@ -36,12 +38,12 @@ absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
|
|||
mediapipe::GetResourceContents(resolved_path, &model_blob));
|
||||
}
|
||||
|
||||
auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
|
||||
model_blob.data(), model_blob.size());
|
||||
auto model = FlatBufferModel::VerifyAndBuildFromBuffer(model_blob.data(),
|
||||
model_blob.size());
|
||||
RET_CHECK(model) << "Failed to load model from path " << model_path;
|
||||
return api2::MakePacket<TfLiteModelPtr>(
|
||||
model.release(),
|
||||
[model_blob = std::move(model_blob)](tflite::FlatBufferModel* model) {
|
||||
[model_blob = std::move(model_blob)](FlatBufferModel* model) {
|
||||
// It's required that model_blob is deleted only after
|
||||
// model is deleted, hence capturing model_blob.
|
||||
delete model;
|
||||
|
|
|
@ -15,16 +15,20 @@
|
|||
#ifndef MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_
|
||||
#define MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/core/shims/cc/model.h"
|
||||
|
||||
namespace mediapipe {
|
||||
// Represents a TfLite model as a FlatBuffer.
|
||||
using TfLiteModelPtr =
|
||||
std::unique_ptr<tflite::FlatBufferModel,
|
||||
std::function<void(tflite::FlatBufferModel*)>>;
|
||||
std::unique_ptr<tflite_shims::FlatBufferModel,
|
||||
std::function<void(tflite_shims::FlatBufferModel*)>>;
|
||||
|
||||
class TfLiteModelLoader {
|
||||
public:
|
||||
|
|
Loading…
Reference in New Issue
Block a user