Depends on TFLite shim header.
PiperOrigin-RevId: 508491302
This commit is contained in:
parent
99fc975f49
commit
fd764dae0a
|
@ -21,6 +21,7 @@ load(
|
||||||
)
|
)
|
||||||
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
|
load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
|
||||||
load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto")
|
load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto")
|
||||||
|
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
@ -370,10 +371,15 @@ mediapipe_proto_library(
|
||||||
# size concerns), depend on those implementations directly, and do not depend on
|
# size concerns), depend on those implementations directly, and do not depend on
|
||||||
# :inference_calculator.
|
# :inference_calculator.
|
||||||
# In all cases, use "InferenceCalulator" in your graphs.
|
# In all cases, use "InferenceCalulator" in your graphs.
|
||||||
cc_library(
|
cc_library_with_tflite(
|
||||||
name = "inference_calculator_interface",
|
name = "inference_calculator_interface",
|
||||||
srcs = ["inference_calculator.cc"],
|
srcs = ["inference_calculator.cc"],
|
||||||
hdrs = ["inference_calculator.h"],
|
hdrs = ["inference_calculator.h"],
|
||||||
|
tflite_deps = [
|
||||||
|
"//mediapipe/util/tflite:tflite_model_loader",
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":inference_calculator_cc_proto",
|
":inference_calculator_cc_proto",
|
||||||
":inference_calculator_options_lib",
|
":inference_calculator_options_lib",
|
||||||
|
@ -384,12 +390,9 @@ cc_library(
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
||||||
"//mediapipe/framework/tool:subgraph_expansion",
|
"//mediapipe/framework/tool:subgraph_expansion",
|
||||||
"//mediapipe/util/tflite:tflite_model_loader",
|
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@org_tensorflow//tensorflow/lite:framework",
|
|
||||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -473,22 +476,24 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library_with_tflite(
|
||||||
name = "inference_interpreter_delegate_runner",
|
name = "inference_interpreter_delegate_runner",
|
||||||
srcs = ["inference_interpreter_delegate_runner.cc"],
|
srcs = ["inference_interpreter_delegate_runner.cc"],
|
||||||
hdrs = ["inference_interpreter_delegate_runner.h"],
|
hdrs = ["inference_interpreter_delegate_runner.h"],
|
||||||
|
tflite_deps = [
|
||||||
|
"//mediapipe/util/tflite:tflite_model_loader",
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":inference_runner",
|
":inference_runner",
|
||||||
"//mediapipe/framework:mediapipe_profiling",
|
"//mediapipe/framework:mediapipe_profiling",
|
||||||
"//mediapipe/framework/api2:packet",
|
"//mediapipe/framework/api2:packet",
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/util/tflite:tflite_model_loader",
|
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
|
||||||
"@org_tensorflow//tensorflow/lite:string_util",
|
"@org_tensorflow//tensorflow/lite:string_util",
|
||||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
|
||||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -506,9 +511,9 @@ cc_library(
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
|
||||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
"//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"],
|
"//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"],
|
||||||
|
|
|
@ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
|
||||||
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
|
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
|
||||||
}
|
}
|
||||||
return PacketAdopting<tflite::OpResolver>(
|
return PacketAdopting<tflite::OpResolver>(
|
||||||
std::make_unique<
|
std::make_unique<tflite_shims::ops::builtin::
|
||||||
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
|
BuiltinOpResolverWithoutDefaultDelegates>());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace api2
|
} // namespace api2
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
@ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf {
|
||||||
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
|
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
|
||||||
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
|
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
|
||||||
// migration.
|
// migration.
|
||||||
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
|
static constexpr SideInput<tflite_shims::ops::builtin::BuiltinOpResolver>::
|
||||||
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
||||||
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
|
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
|
||||||
"OP_RESOLVER"};
|
"OP_RESOLVER"};
|
||||||
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
||||||
|
@ -112,7 +112,8 @@ class InferenceCalculator : public NodeIntf {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
using TfLiteDelegatePtr =
|
using TfLiteDelegatePtr =
|
||||||
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
|
std::unique_ptr<TfLiteOpaqueDelegate,
|
||||||
|
std::function<void(TfLiteOpaqueDelegate*)>>;
|
||||||
|
|
||||||
static absl::StatusOr<Packet<TfLiteModelPtr>> GetModelAsPacket(
|
static absl::StatusOr<Packet<TfLiteModelPtr>> GetModelAsPacket(
|
||||||
CalculatorContext* cc);
|
CalculatorContext* cc);
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
||||||
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
||||||
#include "mediapipe/calculators/tensor/inference_runner.h"
|
#include "mediapipe/calculators/tensor/inference_runner.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
||||||
#if defined(MEDIAPIPE_ANDROID)
|
#if defined(MEDIAPIPE_ANDROID)
|
||||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||||
#endif // ANDROID
|
#endif // ANDROID
|
||||||
|
|
|
@ -22,9 +22,9 @@
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/framework/mediapipe_profiling.h"
|
#include "mediapipe/framework/mediapipe_profiling.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "tensorflow/lite/c/c_api_types.h"
|
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
||||||
#include "tensorflow/lite/interpreter_builder.h"
|
#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
||||||
|
@ -33,9 +33,12 @@ namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using Interpreter = ::tflite_shims::Interpreter;
|
||||||
|
using InterpreterBuilder = ::tflite_shims::InterpreterBuilder;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
||||||
tflite::Interpreter* interpreter,
|
Interpreter* interpreter,
|
||||||
int input_tensor_index) {
|
int input_tensor_index) {
|
||||||
auto input_tensor_view = input_tensor.GetCpuReadView();
|
auto input_tensor_view = input_tensor.GetCpuReadView();
|
||||||
auto input_tensor_buffer = input_tensor_view.buffer<T>();
|
auto input_tensor_buffer = input_tensor_view.buffer<T>();
|
||||||
|
@ -46,7 +49,7 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
|
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
|
||||||
tflite::Interpreter* interpreter,
|
Interpreter* interpreter,
|
||||||
int input_tensor_index) {
|
int input_tensor_index) {
|
||||||
const char* input_tensor_buffer =
|
const char* input_tensor_buffer =
|
||||||
input_tensor.GetCpuReadView().buffer<char>();
|
input_tensor.GetCpuReadView().buffer<char>();
|
||||||
|
@ -58,7 +61,7 @@ void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
|
void CopyTensorBufferFromInterpreter(Interpreter* interpreter,
|
||||||
int output_tensor_index,
|
int output_tensor_index,
|
||||||
Tensor* output_tensor) {
|
Tensor* output_tensor) {
|
||||||
auto output_tensor_view = output_tensor->GetCpuWriteView();
|
auto output_tensor_view = output_tensor->GetCpuWriteView();
|
||||||
|
@ -73,9 +76,8 @@ void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
|
||||||
|
|
||||||
class InferenceInterpreterDelegateRunner : public InferenceRunner {
|
class InferenceInterpreterDelegateRunner : public InferenceRunner {
|
||||||
public:
|
public:
|
||||||
InferenceInterpreterDelegateRunner(
|
InferenceInterpreterDelegateRunner(api2::Packet<TfLiteModelPtr> model,
|
||||||
api2::Packet<TfLiteModelPtr> model,
|
std::unique_ptr<Interpreter> interpreter,
|
||||||
std::unique_ptr<tflite::Interpreter> interpreter,
|
|
||||||
TfLiteDelegatePtr delegate)
|
TfLiteDelegatePtr delegate)
|
||||||
: model_(std::move(model)),
|
: model_(std::move(model)),
|
||||||
interpreter_(std::move(interpreter)),
|
interpreter_(std::move(interpreter)),
|
||||||
|
@ -86,7 +88,7 @@ class InferenceInterpreterDelegateRunner : public InferenceRunner {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
api2::Packet<TfLiteModelPtr> model_;
|
api2::Packet<TfLiteModelPtr> model_;
|
||||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
std::unique_ptr<Interpreter> interpreter_;
|
||||||
TfLiteDelegatePtr delegate_;
|
TfLiteDelegatePtr delegate_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -197,8 +199,7 @@ CreateInferenceInterpreterDelegateRunner(
|
||||||
api2::Packet<TfLiteModelPtr> model,
|
api2::Packet<TfLiteModelPtr> model,
|
||||||
api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate,
|
api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate,
|
||||||
int interpreter_num_threads) {
|
int interpreter_num_threads) {
|
||||||
tflite::InterpreterBuilder interpreter_builder(*model.Get(),
|
InterpreterBuilder interpreter_builder(*model.Get(), op_resolver.Get());
|
||||||
op_resolver.Get());
|
|
||||||
if (delegate) {
|
if (delegate) {
|
||||||
interpreter_builder.AddDelegate(delegate.get());
|
interpreter_builder.AddDelegate(delegate.get());
|
||||||
}
|
}
|
||||||
|
@ -207,7 +208,7 @@ CreateInferenceInterpreterDelegateRunner(
|
||||||
#else
|
#else
|
||||||
interpreter_builder.SetNumThreads(interpreter_num_threads);
|
interpreter_builder.SetNumThreads(interpreter_num_threads);
|
||||||
#endif // __EMSCRIPTEN__
|
#endif // __EMSCRIPTEN__
|
||||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
std::unique_ptr<Interpreter> interpreter;
|
||||||
RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk);
|
RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk);
|
||||||
RET_CHECK(interpreter);
|
RET_CHECK(interpreter);
|
||||||
RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||||
|
|
|
@ -23,12 +23,14 @@
|
||||||
#include "mediapipe/framework/api2/packet.h"
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
// TODO: Consider renaming TfLiteDelegatePtr.
|
||||||
using TfLiteDelegatePtr =
|
using TfLiteDelegatePtr =
|
||||||
std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>;
|
std::unique_ptr<TfLiteOpaqueDelegate,
|
||||||
|
std::function<void(TfLiteOpaqueDelegate*)>>;
|
||||||
|
|
||||||
// Creates inference runner which run inference using newly initialized
|
// Creates inference runner which run inference using newly initialized
|
||||||
// interpreter and provided `delegate`.
|
// interpreter and provided `delegate`.
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
package(default_visibility = [
|
package(default_visibility = [
|
||||||
|
@ -110,10 +112,13 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library_with_tflite(
|
||||||
name = "tflite_model_loader",
|
name = "tflite_model_loader",
|
||||||
srcs = ["tflite_model_loader.cc"],
|
srcs = ["tflite_model_loader.cc"],
|
||||||
hdrs = ["tflite_model_loader.h"],
|
hdrs = ["tflite_model_loader.h"],
|
||||||
|
tflite_deps = [
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||||
|
],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework/api2:packet",
|
"//mediapipe/framework/api2:packet",
|
||||||
|
@ -121,6 +126,5 @@ cc_library(
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/port:statusor",
|
"//mediapipe/framework/port:statusor",
|
||||||
"//mediapipe/util:resource_util",
|
"//mediapipe/util:resource_util",
|
||||||
"@org_tensorflow//tensorflow/lite:framework",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,6 +19,8 @@
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
using FlatBufferModel = ::tflite_shims::FlatBufferModel;
|
||||||
|
|
||||||
absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
|
absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
|
||||||
const std::string& path) {
|
const std::string& path) {
|
||||||
std::string model_path = path;
|
std::string model_path = path;
|
||||||
|
@ -36,12 +38,12 @@ absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
|
||||||
mediapipe::GetResourceContents(resolved_path, &model_blob));
|
mediapipe::GetResourceContents(resolved_path, &model_blob));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
|
auto model = FlatBufferModel::VerifyAndBuildFromBuffer(model_blob.data(),
|
||||||
model_blob.data(), model_blob.size());
|
model_blob.size());
|
||||||
RET_CHECK(model) << "Failed to load model from path " << model_path;
|
RET_CHECK(model) << "Failed to load model from path " << model_path;
|
||||||
return api2::MakePacket<TfLiteModelPtr>(
|
return api2::MakePacket<TfLiteModelPtr>(
|
||||||
model.release(),
|
model.release(),
|
||||||
[model_blob = std::move(model_blob)](tflite::FlatBufferModel* model) {
|
[model_blob = std::move(model_blob)](FlatBufferModel* model) {
|
||||||
// It's required that model_blob is deleted only after
|
// It's required that model_blob is deleted only after
|
||||||
// model is deleted, hence capturing model_blob.
|
// model is deleted, hence capturing model_blob.
|
||||||
delete model;
|
delete model;
|
||||||
|
|
|
@ -15,16 +15,20 @@
|
||||||
#ifndef MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_
|
#ifndef MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_
|
||||||
#define MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_
|
#define MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "mediapipe/framework/api2/packet.h"
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/statusor.h"
|
#include "mediapipe/framework/port/statusor.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/core/shims/cc/model.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
// Represents a TfLite model as a FlatBuffer.
|
// Represents a TfLite model as a FlatBuffer.
|
||||||
using TfLiteModelPtr =
|
using TfLiteModelPtr =
|
||||||
std::unique_ptr<tflite::FlatBufferModel,
|
std::unique_ptr<tflite_shims::FlatBufferModel,
|
||||||
std::function<void(tflite::FlatBufferModel*)>>;
|
std::function<void(tflite_shims::FlatBufferModel*)>>;
|
||||||
|
|
||||||
class TfLiteModelLoader {
|
class TfLiteModelLoader {
|
||||||
public:
|
public:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user