Depends on TFLite shim header.

PiperOrigin-RevId: 508491302
This commit is contained in:
MediaPipe Team 2023-02-09 15:25:20 -08:00 committed by Copybara-Service
parent 99fc975f49
commit fd764dae0a
9 changed files with 60 additions and 41 deletions

View File

@ -21,6 +21,7 @@ load(
) )
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto") load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto")
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
licenses(["notice"]) licenses(["notice"])
@ -370,10 +371,15 @@ mediapipe_proto_library(
# size concerns), depend on those implementations directly, and do not depend on # size concerns), depend on those implementations directly, and do not depend on
# :inference_calculator. # :inference_calculator.
# In all cases, use "InferenceCalulator" in your graphs. # In all cases, use "InferenceCalulator" in your graphs.
cc_library( cc_library_with_tflite(
name = "inference_calculator_interface", name = "inference_calculator_interface",
srcs = ["inference_calculator.cc"], srcs = ["inference_calculator.cc"],
hdrs = ["inference_calculator.h"], hdrs = ["inference_calculator.h"],
tflite_deps = [
"//mediapipe/util/tflite:tflite_model_loader",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
],
deps = [ deps = [
":inference_calculator_cc_proto", ":inference_calculator_cc_proto",
":inference_calculator_options_lib", ":inference_calculator_options_lib",
@ -384,12 +390,9 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/tool:subgraph_expansion", "//mediapipe/framework/tool:subgraph_expansion",
"//mediapipe/util/tflite:tflite_model_loader",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -473,22 +476,24 @@ cc_library(
], ],
) )
cc_library( cc_library_with_tflite(
name = "inference_interpreter_delegate_runner", name = "inference_interpreter_delegate_runner",
srcs = ["inference_interpreter_delegate_runner.cc"], srcs = ["inference_interpreter_delegate_runner.cc"],
hdrs = ["inference_interpreter_delegate_runner.h"], hdrs = ["inference_interpreter_delegate_runner.h"],
tflite_deps = [
"//mediapipe/util/tflite:tflite_model_loader",
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
],
deps = [ deps = [
":inference_runner", ":inference_runner",
"//mediapipe/framework:mediapipe_profiling", "//mediapipe/framework:mediapipe_profiling",
"//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:packet",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/util/tflite:tflite_model_loader",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite:string_util", "@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
], ],
) )
@ -506,9 +511,9 @@ cc_library(
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
] + select({ ] + select({
"//conditions:default": [], "//conditions:default": [],
"//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"], "//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"],

View File

@ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>(); return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
} }
return PacketAdopting<tflite::OpResolver>( return PacketAdopting<tflite::OpResolver>(
std::make_unique< std::make_unique<tflite_shims::ops::builtin::
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>()); BuiltinOpResolverWithoutDefaultDelegates>());
} }
} // namespace api2 } // namespace api2

View File

@ -26,7 +26,7 @@
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/util/tflite/tflite_model_loader.h" #include "mediapipe/util/tflite/tflite_model_loader.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/core/shims/cc/kernels/register.h"
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
@ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf {
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead. // Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the // TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
// migration. // migration.
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional static constexpr SideInput<tflite_shims::ops::builtin::BuiltinOpResolver>::
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{ static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
"OP_RESOLVER"}; "OP_RESOLVER"};
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"}; static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
@ -112,7 +112,8 @@ class InferenceCalculator : public NodeIntf {
protected: protected:
using TfLiteDelegatePtr = using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>; std::unique_ptr<TfLiteOpaqueDelegate,
std::function<void(TfLiteOpaqueDelegate*)>>;
static absl::StatusOr<Packet<TfLiteModelPtr>> GetModelAsPacket( static absl::StatusOr<Packet<TfLiteModelPtr>> GetModelAsPacket(
CalculatorContext* cc); CalculatorContext* cc);

View File

@ -24,7 +24,7 @@
#include "mediapipe/calculators/tensor/inference_calculator_utils.h" #include "mediapipe/calculators/tensor/inference_calculator_utils.h"
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h" #include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
#include "mediapipe/calculators/tensor/inference_runner.h" #include "mediapipe/calculators/tensor/inference_runner.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/core/shims/cc/interpreter.h"
#if defined(MEDIAPIPE_ANDROID) #if defined(MEDIAPIPE_ANDROID)
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#endif // ANDROID #endif // ANDROID

View File

@ -22,9 +22,9 @@
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/mediapipe_profiling.h" #include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/shims/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/core/shims/cc/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
#include "tensorflow/lite/string_util.h" #include "tensorflow/lite/string_util.h"
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe #define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
@ -33,9 +33,12 @@ namespace mediapipe {
namespace { namespace {
using Interpreter = ::tflite_shims::Interpreter;
using InterpreterBuilder = ::tflite_shims::InterpreterBuilder;
template <typename T> template <typename T>
void CopyTensorBufferToInterpreter(const Tensor& input_tensor, void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
tflite::Interpreter* interpreter, Interpreter* interpreter,
int input_tensor_index) { int input_tensor_index) {
auto input_tensor_view = input_tensor.GetCpuReadView(); auto input_tensor_view = input_tensor.GetCpuReadView();
auto input_tensor_buffer = input_tensor_view.buffer<T>(); auto input_tensor_buffer = input_tensor_view.buffer<T>();
@ -46,7 +49,7 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
template <> template <>
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor, void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
tflite::Interpreter* interpreter, Interpreter* interpreter,
int input_tensor_index) { int input_tensor_index) {
const char* input_tensor_buffer = const char* input_tensor_buffer =
input_tensor.GetCpuReadView().buffer<char>(); input_tensor.GetCpuReadView().buffer<char>();
@ -58,7 +61,7 @@ void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
} }
template <typename T> template <typename T>
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter, void CopyTensorBufferFromInterpreter(Interpreter* interpreter,
int output_tensor_index, int output_tensor_index,
Tensor* output_tensor) { Tensor* output_tensor) {
auto output_tensor_view = output_tensor->GetCpuWriteView(); auto output_tensor_view = output_tensor->GetCpuWriteView();
@ -73,10 +76,9 @@ void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
class InferenceInterpreterDelegateRunner : public InferenceRunner { class InferenceInterpreterDelegateRunner : public InferenceRunner {
public: public:
InferenceInterpreterDelegateRunner( InferenceInterpreterDelegateRunner(api2::Packet<TfLiteModelPtr> model,
api2::Packet<TfLiteModelPtr> model, std::unique_ptr<Interpreter> interpreter,
std::unique_ptr<tflite::Interpreter> interpreter, TfLiteDelegatePtr delegate)
TfLiteDelegatePtr delegate)
: model_(std::move(model)), : model_(std::move(model)),
interpreter_(std::move(interpreter)), interpreter_(std::move(interpreter)),
delegate_(std::move(delegate)) {} delegate_(std::move(delegate)) {}
@ -86,7 +88,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
private: private:
api2::Packet<TfLiteModelPtr> model_; api2::Packet<TfLiteModelPtr> model_;
std::unique_ptr<tflite::Interpreter> interpreter_; std::unique_ptr<Interpreter> interpreter_;
TfLiteDelegatePtr delegate_; TfLiteDelegatePtr delegate_;
}; };
@ -197,8 +199,7 @@ CreateInferenceInterpreterDelegateRunner(
api2::Packet<TfLiteModelPtr> model, api2::Packet<TfLiteModelPtr> model,
api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate, api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate,
int interpreter_num_threads) { int interpreter_num_threads) {
tflite::InterpreterBuilder interpreter_builder(*model.Get(), InterpreterBuilder interpreter_builder(*model.Get(), op_resolver.Get());
op_resolver.Get());
if (delegate) { if (delegate) {
interpreter_builder.AddDelegate(delegate.get()); interpreter_builder.AddDelegate(delegate.get());
} }
@ -207,7 +208,7 @@ CreateInferenceInterpreterDelegateRunner(
#else #else
interpreter_builder.SetNumThreads(interpreter_num_threads); interpreter_builder.SetNumThreads(interpreter_num_threads);
#endif // __EMSCRIPTEN__ #endif // __EMSCRIPTEN__
std::unique_ptr<tflite::Interpreter> interpreter; std::unique_ptr<Interpreter> interpreter;
RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk); RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk);
RET_CHECK(interpreter); RET_CHECK(interpreter);
RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk); RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk);

View File

@ -23,12 +23,14 @@
#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/packet.h"
#include "mediapipe/util/tflite/tflite_model_loader.h" #include "mediapipe/util/tflite/tflite_model_loader.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/core/shims/c/c_api_types.h"
namespace mediapipe { namespace mediapipe {
// TODO: Consider renaming TfLiteDelegatePtr.
using TfLiteDelegatePtr = using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>; std::unique_ptr<TfLiteOpaqueDelegate,
std::function<void(TfLiteOpaqueDelegate*)>>;
// Creates inference runner which run inference using newly initialized // Creates inference runner which run inference using newly initialized
// interpreter and provided `delegate`. // interpreter and provided `delegate`.

View File

@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
# #
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
licenses(["notice"]) licenses(["notice"])
package(default_visibility = [ package(default_visibility = [
@ -110,10 +112,13 @@ cc_library(
], ],
) )
cc_library( cc_library_with_tflite(
name = "tflite_model_loader", name = "tflite_model_loader",
srcs = ["tflite_model_loader.cc"], srcs = ["tflite_model_loader.cc"],
hdrs = ["tflite_model_loader.h"], hdrs = ["tflite_model_loader.h"],
tflite_deps = [
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:packet",
@ -121,6 +126,5 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"@org_tensorflow//tensorflow/lite:framework",
], ],
) )

View File

@ -19,6 +19,8 @@
namespace mediapipe { namespace mediapipe {
using FlatBufferModel = ::tflite_shims::FlatBufferModel;
absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath( absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
const std::string& path) { const std::string& path) {
std::string model_path = path; std::string model_path = path;
@ -36,12 +38,12 @@ absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
mediapipe::GetResourceContents(resolved_path, &model_blob)); mediapipe::GetResourceContents(resolved_path, &model_blob));
} }
auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( auto model = FlatBufferModel::VerifyAndBuildFromBuffer(model_blob.data(),
model_blob.data(), model_blob.size()); model_blob.size());
RET_CHECK(model) << "Failed to load model from path " << model_path; RET_CHECK(model) << "Failed to load model from path " << model_path;
return api2::MakePacket<TfLiteModelPtr>( return api2::MakePacket<TfLiteModelPtr>(
model.release(), model.release(),
[model_blob = std::move(model_blob)](tflite::FlatBufferModel* model) { [model_blob = std::move(model_blob)](FlatBufferModel* model) {
// It's required that model_blob is deleted only after // It's required that model_blob is deleted only after
// model is deleted, hence capturing model_blob. // model is deleted, hence capturing model_blob.
delete model; delete model;

View File

@ -15,16 +15,20 @@
#ifndef MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_ #ifndef MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_
#define MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_ #define MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_
#include <functional>
#include <memory>
#include <string>
#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/core/shims/cc/model.h"
namespace mediapipe { namespace mediapipe {
// Represents a TfLite model as a FlatBuffer. // Represents a TfLite model as a FlatBuffer.
using TfLiteModelPtr = using TfLiteModelPtr =
std::unique_ptr<tflite::FlatBufferModel, std::unique_ptr<tflite_shims::FlatBufferModel,
std::function<void(tflite::FlatBufferModel*)>>; std::function<void(tflite_shims::FlatBufferModel*)>>;
class TfLiteModelLoader { class TfLiteModelLoader {
public: public: