diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 69d666092..1ac5644c1 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -21,6 +21,7 @@ load( ) load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") licenses(["notice"]) @@ -370,10 +371,15 @@ mediapipe_proto_library( # size concerns), depend on those implementations directly, and do not depend on # :inference_calculator. # In all cases, use "InferenceCalulator" in your graphs. -cc_library( +cc_library_with_tflite( name = "inference_calculator_interface", srcs = ["inference_calculator.cc"], hdrs = ["inference_calculator.h"], + tflite_deps = [ + "//mediapipe/util/tflite:tflite_model_loader", + "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + ], deps = [ ":inference_calculator_cc_proto", ":inference_calculator_options_lib", @@ -384,12 +390,9 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/tool:subgraph_expansion", - "//mediapipe/util/tflite:tflite_model_loader", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], alwayslink = 1, ) @@ -473,22 +476,24 @@ cc_library( ], ) -cc_library( +cc_library_with_tflite( name = "inference_interpreter_delegate_runner", srcs = ["inference_interpreter_delegate_runner.cc"], hdrs = ["inference_interpreter_delegate_runner.h"], + tflite_deps = [ + "//mediapipe/util/tflite:tflite_model_loader", + "@org_tensorflow//tensorflow/lite/core/shims:c_api_types", + "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", + ], deps = [ ":inference_runner", "//mediapipe/framework:mediapipe_profiling", "//mediapipe/framework/api2:packet", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", - "//mediapipe/util/tflite:tflite_model_loader", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite:framework_stable", "@org_tensorflow//tensorflow/lite:string_util", - "@org_tensorflow//tensorflow/lite/c:c_api_types", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", ], ) @@ -506,9 +511,9 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/shims:c_api_types", + "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", - "@org_tensorflow//tensorflow/lite:framework_stable", - "@org_tensorflow//tensorflow/lite/c:c_api_types", ] + select({ "//conditions:default": [], "//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"], diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index 2a6936eba..c53c6e3d5 100644 --- a/mediapipe/calculators/tensor/inference_calculator.cc +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) { return kSideInCustomOpResolver(cc).As(); } return PacketAdopting( - std::make_unique< - tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>()); + std::make_unique()); } } // namespace api2 diff --git a/mediapipe/calculators/tensor/inference_calculator.h b/mediapipe/calculators/tensor/inference_calculator.h index 5df5f993f..b73f42053 100644 --- a/mediapipe/calculators/tensor/inference_calculator.h +++ b/mediapipe/calculators/tensor/inference_calculator.h @@ -26,7 +26,7 @@ #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/util/tflite/tflite_model_loader.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" namespace mediapipe { namespace api2 { @@ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf { // Deprecated. Prefers to use "OP_RESOLVER" input side packet instead. // TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the // migration. - static constexpr SideInput::Optional - kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; + static constexpr SideInput:: + Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; static constexpr SideInput::Optional kSideInOpResolver{ "OP_RESOLVER"}; static constexpr SideInput::Optional kSideInModel{"MODEL"}; @@ -112,7 +112,8 @@ class InferenceCalculator : public NodeIntf { protected: using TfLiteDelegatePtr = - std::unique_ptr>; + std::unique_ptr>; static absl::StatusOr> GetModelAsPacket( CalculatorContext* cc); diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc index 46804b6fd..7a837111d 100644 --- a/mediapipe/calculators/tensor/inference_calculator_cpu.cc +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -24,7 +24,7 @@ #include "mediapipe/calculators/tensor/inference_calculator_utils.h" #include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h" #include "mediapipe/calculators/tensor/inference_runner.h" -#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/core/shims/cc/interpreter.h" #if defined(MEDIAPIPE_ANDROID) #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #endif // ANDROID diff --git a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc index 9ef4e822c..28781e97a 100644 --- a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc +++ b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc @@ -22,9 +22,9 @@ #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/mediapipe_profiling.h" #include "mediapipe/framework/port/ret_check.h" -#include "tensorflow/lite/c/c_api_types.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/interpreter_builder.h" +#include "tensorflow/lite/core/shims/c/c_api_types.h" +#include "tensorflow/lite/core/shims/cc/interpreter.h" +#include "tensorflow/lite/core/shims/cc/interpreter_builder.h" #include "tensorflow/lite/string_util.h" #define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe @@ -33,9 +33,12 @@ namespace mediapipe { namespace { +using Interpreter = ::tflite_shims::Interpreter; +using InterpreterBuilder = ::tflite_shims::InterpreterBuilder; + template void CopyTensorBufferToInterpreter(const Tensor& input_tensor, - tflite::Interpreter* interpreter, + Interpreter* interpreter, int input_tensor_index) { auto input_tensor_view = input_tensor.GetCpuReadView(); auto input_tensor_buffer = input_tensor_view.buffer(); @@ -46,7 +49,7 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor, template <> void CopyTensorBufferToInterpreter(const Tensor& input_tensor, - tflite::Interpreter* interpreter, + Interpreter* interpreter, int input_tensor_index) { const char* input_tensor_buffer = input_tensor.GetCpuReadView().buffer(); @@ -58,7 +61,7 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor, } template -void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter, +void CopyTensorBufferFromInterpreter(Interpreter* interpreter, int output_tensor_index, Tensor* output_tensor) { auto output_tensor_view = output_tensor->GetCpuWriteView(); @@ -73,10 +76,9 @@ void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter, class InferenceInterpreterDelegateRunner : public InferenceRunner { public: - InferenceInterpreterDelegateRunner( - api2::Packet model, - std::unique_ptr interpreter, - TfLiteDelegatePtr delegate) + InferenceInterpreterDelegateRunner(api2::Packet model, + std::unique_ptr interpreter, + TfLiteDelegatePtr delegate) : model_(std::move(model)), interpreter_(std::move(interpreter)), delegate_(std::move(delegate)) {} @@ -86,7 +88,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner { private: api2::Packet model_; - std::unique_ptr interpreter_; + std::unique_ptr interpreter_; TfLiteDelegatePtr delegate_; }; @@ -197,8 +199,7 @@ CreateInferenceInterpreterDelegateRunner( api2::Packet model, api2::Packet op_resolver, TfLiteDelegatePtr delegate, int interpreter_num_threads) { - tflite::InterpreterBuilder interpreter_builder(*model.Get(), - op_resolver.Get()); + InterpreterBuilder interpreter_builder(*model.Get(), op_resolver.Get()); if (delegate) { interpreter_builder.AddDelegate(delegate.get()); } @@ -207,7 +208,7 @@ CreateInferenceInterpreterDelegateRunner( #else interpreter_builder.SetNumThreads(interpreter_num_threads); #endif // __EMSCRIPTEN__ - std::unique_ptr interpreter; + std::unique_ptr interpreter; RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk); RET_CHECK(interpreter); RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk); diff --git a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h index bfe27868e..44d6a932f 100644 --- a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h +++ b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h @@ -23,12 +23,14 @@ #include "mediapipe/framework/api2/packet.h" #include "mediapipe/util/tflite/tflite_model_loader.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/core/shims/c/c_api_types.h" namespace mediapipe { +// TODO: Consider renaming TfLiteDelegatePtr. using TfLiteDelegatePtr = - std::unique_ptr>; + std::unique_ptr>; // Creates inference runner which run inference using newly initialized // interpreter and provided `delegate`. diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index e9b8bfa03..26d73a6a6 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -13,6 +13,8 @@ # limitations under the License. # +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + licenses(["notice"]) package(default_visibility = [ @@ -110,10 +112,13 @@ cc_library( ], ) -cc_library( +cc_library_with_tflite( name = "tflite_model_loader", srcs = ["tflite_model_loader.cc"], hdrs = ["tflite_model_loader.h"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", + ], visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/api2:packet", @@ -121,6 +126,5 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/util:resource_util", - "@org_tensorflow//tensorflow/lite:framework", ], ) diff --git a/mediapipe/util/tflite/tflite_model_loader.cc b/mediapipe/util/tflite/tflite_model_loader.cc index aab94ccbd..86fc260bb 100644 --- a/mediapipe/util/tflite/tflite_model_loader.cc +++ b/mediapipe/util/tflite/tflite_model_loader.cc @@ -19,6 +19,8 @@ namespace mediapipe { +using FlatBufferModel = ::tflite_shims::FlatBufferModel; + absl::StatusOr> TfLiteModelLoader::LoadFromPath( const std::string& path) { std::string model_path = path; @@ -36,12 +38,12 @@ absl::StatusOr> TfLiteModelLoader::LoadFromPath( mediapipe::GetResourceContents(resolved_path, &model_blob)); } - auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( - model_blob.data(), model_blob.size()); + auto model = FlatBufferModel::VerifyAndBuildFromBuffer(model_blob.data(), + model_blob.size()); RET_CHECK(model) << "Failed to load model from path " << model_path; return api2::MakePacket( model.release(), - [model_blob = std::move(model_blob)](tflite::FlatBufferModel* model) { + [model_blob = std::move(model_blob)](FlatBufferModel* model) { // It's required that model_blob is deleted only after // model is deleted, hence capturing model_blob. delete model; diff --git a/mediapipe/util/tflite/tflite_model_loader.h b/mediapipe/util/tflite/tflite_model_loader.h index c1baf128a..8c630ec8d 100644 --- a/mediapipe/util/tflite/tflite_model_loader.h +++ b/mediapipe/util/tflite/tflite_model_loader.h @@ -15,16 +15,20 @@ #ifndef MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_ #define MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_ +#include +#include +#include + #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/core/shims/cc/model.h" namespace mediapipe { // Represents a TfLite model as a FlatBuffer. using TfLiteModelPtr = - std::unique_ptr>; + std::unique_ptr>; class TfLiteModelLoader { public: