Depends on TFLite shim header.

PiperOrigin-RevId: 508491302
This commit is contained in:
MediaPipe Team 2023-02-09 15:25:20 -08:00 committed by Copybara-Service
parent 99fc975f49
commit fd764dae0a
9 changed files with 60 additions and 41 deletions

View File

@ -21,6 +21,7 @@ load(
)
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto")
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
licenses(["notice"])
@ -370,10 +371,15 @@ mediapipe_proto_library(
# size concerns), depend on those implementations directly, and do not depend on
# :inference_calculator.
# In all cases, use "InferenceCalulator" in your graphs.
cc_library(
cc_library_with_tflite(
name = "inference_calculator_interface",
srcs = ["inference_calculator.cc"],
hdrs = ["inference_calculator.h"],
tflite_deps = [
"//mediapipe/util/tflite:tflite_model_loader",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
],
deps = [
":inference_calculator_cc_proto",
":inference_calculator_options_lib",
@ -384,12 +390,9 @@ cc_library(
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/tool:subgraph_expansion",
"//mediapipe/util/tflite:tflite_model_loader",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
alwayslink = 1,
)
@ -473,22 +476,24 @@ cc_library(
],
)
cc_library(
cc_library_with_tflite(
name = "inference_interpreter_delegate_runner",
srcs = ["inference_interpreter_delegate_runner.cc"],
hdrs = ["inference_interpreter_delegate_runner.h"],
tflite_deps = [
"//mediapipe/util/tflite:tflite_model_loader",
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
],
deps = [
":inference_runner",
"//mediapipe/framework:mediapipe_profiling",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/util/tflite:tflite_model_loader",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
],
)
@ -506,9 +511,9 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
] + select({
"//conditions:default": [],
"//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"],

View File

@ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
}
return PacketAdopting<tflite::OpResolver>(
std::make_unique<
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
std::make_unique<tflite_shims::ops::builtin::
BuiltinOpResolverWithoutDefaultDelegates>());
}
} // namespace api2

View File

@ -26,7 +26,7 @@
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/util/tflite/tflite_model_loader.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
namespace mediapipe {
namespace api2 {
@ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf {
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
// migration.
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
static constexpr SideInput<tflite_shims::ops::builtin::BuiltinOpResolver>::
Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
"OP_RESOLVER"};
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
@ -112,7 +112,8 @@ class InferenceCalculator : public NodeIntf {
protected:
using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
std::unique_ptr<TfLiteOpaqueDelegate,
std::function<void(TfLiteOpaqueDelegate*)>>;
static absl::StatusOr<Packet<TfLiteModelPtr>> GetModelAsPacket(
CalculatorContext* cc);

View File

@ -24,7 +24,7 @@
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
#include "mediapipe/calculators/tensor/inference_runner.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/core/shims/cc/interpreter.h"
#if defined(MEDIAPIPE_ANDROID)
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#endif // ANDROID

View File

@ -22,9 +22,9 @@
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/core/shims/c/c_api_types.h"
#include "tensorflow/lite/core/shims/cc/interpreter.h"
#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
@ -33,9 +33,12 @@ namespace mediapipe {
namespace {
using Interpreter = ::tflite_shims::Interpreter;
using InterpreterBuilder = ::tflite_shims::InterpreterBuilder;
template <typename T>
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
tflite::Interpreter* interpreter,
Interpreter* interpreter,
int input_tensor_index) {
auto input_tensor_view = input_tensor.GetCpuReadView();
auto input_tensor_buffer = input_tensor_view.buffer<T>();
@ -46,7 +49,7 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
template <>
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
tflite::Interpreter* interpreter,
Interpreter* interpreter,
int input_tensor_index) {
const char* input_tensor_buffer =
input_tensor.GetCpuReadView().buffer<char>();
@ -58,7 +61,7 @@ void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
}
template <typename T>
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
void CopyTensorBufferFromInterpreter(Interpreter* interpreter,
int output_tensor_index,
Tensor* output_tensor) {
auto output_tensor_view = output_tensor->GetCpuWriteView();
@ -73,9 +76,8 @@ void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
class InferenceInterpreterDelegateRunner : public InferenceRunner {
public:
InferenceInterpreterDelegateRunner(
api2::Packet<TfLiteModelPtr> model,
std::unique_ptr<tflite::Interpreter> interpreter,
InferenceInterpreterDelegateRunner(api2::Packet<TfLiteModelPtr> model,
std::unique_ptr<Interpreter> interpreter,
TfLiteDelegatePtr delegate)
: model_(std::move(model)),
interpreter_(std::move(interpreter)),
@ -86,7 +88,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
private:
api2::Packet<TfLiteModelPtr> model_;
std::unique_ptr<tflite::Interpreter> interpreter_;
std::unique_ptr<Interpreter> interpreter_;
TfLiteDelegatePtr delegate_;
};
@ -197,8 +199,7 @@ CreateInferenceInterpreterDelegateRunner(
api2::Packet<TfLiteModelPtr> model,
api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate,
int interpreter_num_threads) {
tflite::InterpreterBuilder interpreter_builder(*model.Get(),
op_resolver.Get());
InterpreterBuilder interpreter_builder(*model.Get(), op_resolver.Get());
if (delegate) {
interpreter_builder.AddDelegate(delegate.get());
}
@ -207,7 +208,7 @@ CreateInferenceInterpreterDelegateRunner(
#else
interpreter_builder.SetNumThreads(interpreter_num_threads);
#endif // __EMSCRIPTEN__
std::unique_ptr<tflite::Interpreter> interpreter;
std::unique_ptr<Interpreter> interpreter;
RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk);
RET_CHECK(interpreter);
RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk);

View File

@ -23,12 +23,14 @@
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/util/tflite/tflite_model_loader.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/core/shims/c/c_api_types.h"
namespace mediapipe {
// TODO: Consider renaming TfLiteDelegatePtr.
using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
std::unique_ptr<TfLiteOpaqueDelegate,
std::function<void(TfLiteOpaqueDelegate*)>>;
// Creates inference runner which run inference using newly initialized
// interpreter and provided `delegate`.

View File

@ -13,6 +13,8 @@
# limitations under the License.
#
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
licenses(["notice"])
package(default_visibility = [
@ -110,10 +112,13 @@ cc_library(
],
)
cc_library(
cc_library_with_tflite(
name = "tflite_model_loader",
srcs = ["tflite_model_loader.cc"],
hdrs = ["tflite_model_loader.h"],
tflite_deps = [
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/api2:packet",
@ -121,6 +126,5 @@ cc_library(
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/util:resource_util",
"@org_tensorflow//tensorflow/lite:framework",
],
)

View File

@ -19,6 +19,8 @@
namespace mediapipe {
using FlatBufferModel = ::tflite_shims::FlatBufferModel;
absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
const std::string& path) {
std::string model_path = path;
@ -36,12 +38,12 @@ absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
mediapipe::GetResourceContents(resolved_path, &model_blob));
}
auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
model_blob.data(), model_blob.size());
auto model = FlatBufferModel::VerifyAndBuildFromBuffer(model_blob.data(),
model_blob.size());
RET_CHECK(model) << "Failed to load model from path " << model_path;
return api2::MakePacket<TfLiteModelPtr>(
model.release(),
[model_blob = std::move(model_blob)](tflite::FlatBufferModel* model) {
[model_blob = std::move(model_blob)](FlatBufferModel* model) {
// It's required that model_blob is deleted only after
// model is deleted, hence capturing model_blob.
delete model;

View File

@ -15,16 +15,20 @@
#ifndef MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_
#define MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_
#include <functional>
#include <memory>
#include <string>
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/core/shims/cc/model.h"
namespace mediapipe {
// Represents a TfLite model as a FlatBuffer.
using TfLiteModelPtr =
std::unique_ptr<tflite::FlatBufferModel,
std::function<void(tflite::FlatBufferModel*)>>;
std::unique_ptr<tflite_shims::FlatBufferModel,
std::function<void(tflite_shims::FlatBufferModel*)>>;
class TfLiteModelLoader {
public: