Internal change
PiperOrigin-RevId: 538215311
This commit is contained in:
parent
37290f0224
commit
709eb812cc
|
@ -67,17 +67,20 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
# TODO: Re-evaluate which of these libraries we can avoid making
|
||||
# cc_library_with_tflite and can be changed back to cc_library.
|
||||
# This target has an implementation dependency on TFLite/TFLite-in-GMSCore,
|
||||
# but it does not have any API dependency on TFLite-in-GMSCore.
|
||||
cc_library_with_tflite(
|
||||
name = "op_resolver",
|
||||
srcs = ["op_resolver.cc"],
|
||||
hdrs = ["op_resolver.h"],
|
||||
tflite_deps = [
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_experimental", # For c_api_opaque.h
|
||||
"@org_tensorflow//tensorflow/lite/c:common", # For builtin_op_data.h
|
||||
],
|
||||
deps = [
|
||||
"@org_tensorflow//tensorflow/lite:builtin_op_data",
|
||||
"@org_tensorflow//tensorflow/lite:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -14,47 +14,84 @@
|
|||
|
||||
#include "mediapipe/util/tflite/op_resolver.h"
|
||||
|
||||
#include "tensorflow/lite/builtin_op_data.h"
|
||||
#include "tensorflow/lite/builtin_ops.h"
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api.h"
|
||||
#include "tensorflow/lite/c/c_api_opaque.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr char kMaxPoolingWithArgmax2DOpName[] = "MaxPoolingWithArgmax2D";
|
||||
constexpr int kMaxPoolingWithArgmax2DOpVersion = 1;
|
||||
|
||||
constexpr char kMaxUnpooling2DOpName[] = "MaxUnpooling2D";
|
||||
constexpr int kMaxUnpooling2DOpVersion = 1;
|
||||
|
||||
constexpr char kConvolution2DTransposeBiasOpName[] =
|
||||
"Convolution2DTransposeBias";
|
||||
constexpr int kConvolution2DTransposeBiasOpVersion = 1;
|
||||
|
||||
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D() {
|
||||
static TfLiteRegistration reg = {
|
||||
[](TfLiteContext*, const char*, size_t) -> void* {
|
||||
static TfLiteRegistrationExternal* reg_external = []() {
|
||||
// Intentionally allocated and never destroyed.
|
||||
auto* r = TfLiteRegistrationExternalCreate(
|
||||
kTfLiteBuiltinCustom, kMaxPoolingWithArgmax2DOpName,
|
||||
kMaxPoolingWithArgmax2DOpVersion);
|
||||
TfLiteRegistrationExternalSetInit(
|
||||
r, [](TfLiteOpaqueContext*, const char*, size_t) -> void* {
|
||||
return new TfLitePaddingValues();
|
||||
},
|
||||
[](TfLiteContext*, void* buffer) -> void {
|
||||
});
|
||||
TfLiteRegistrationExternalSetFree(
|
||||
r, [](TfLiteOpaqueContext*, void* buffer) -> void {
|
||||
delete reinterpret_cast<TfLitePaddingValues*>(buffer);
|
||||
},
|
||||
[](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
|
||||
return kTfLiteOk;
|
||||
},
|
||||
[](TfLiteContext* context, TfLiteNode*) -> TfLiteStatus {
|
||||
context->ReportError(
|
||||
});
|
||||
TfLiteRegistrationExternalSetPrepare(
|
||||
r,
|
||||
[](TfLiteOpaqueContext* context,
|
||||
TfLiteOpaqueNode* node) -> TfLiteStatus { return kTfLiteOk; });
|
||||
TfLiteRegistrationExternalSetInvoke(
|
||||
r, [](TfLiteOpaqueContext* context, TfLiteOpaqueNode*) -> TfLiteStatus {
|
||||
TfLiteOpaqueContextReportError(
|
||||
context, "MaxPoolingWithArgmax2D is only available on the GPU.");
|
||||
return kTfLiteError;
|
||||
},
|
||||
};
|
||||
});
|
||||
return r;
|
||||
}();
|
||||
static TfLiteRegistration reg = {.registration_external = reg_external};
|
||||
return ®
|
||||
}
|
||||
|
||||
TfLiteRegistration* RegisterMaxUnpooling2D() {
|
||||
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
|
||||
static TfLiteRegistrationExternal* reg_external =
|
||||
// Intentionally allocated and never destroyed.
|
||||
TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom,
|
||||
kMaxUnpooling2DOpName,
|
||||
kMaxUnpooling2DOpVersion);
|
||||
static TfLiteRegistration reg = {.registration_external = reg_external};
|
||||
return ®
|
||||
}
|
||||
|
||||
TfLiteRegistration* RegisterConvolution2DTransposeBias() {
|
||||
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
|
||||
static TfLiteRegistrationExternal* reg_external =
|
||||
// Intentionally allocated and never destroyed.
|
||||
TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom,
|
||||
kConvolution2DTransposeBiasOpName,
|
||||
kConvolution2DTransposeBiasOpVersion);
|
||||
static TfLiteRegistration reg = {.registration_external = reg_external};
|
||||
return ®
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
OpResolver::OpResolver() {
|
||||
AddCustom("MaxPoolingWithArgmax2D", RegisterMaxPoolingWithArgmax2D());
|
||||
AddCustom("MaxUnpooling2D", RegisterMaxUnpooling2D());
|
||||
AddCustom("Convolution2DTransposeBias", RegisterConvolution2DTransposeBias());
|
||||
AddCustom(kMaxPoolingWithArgmax2DOpName, RegisterMaxPoolingWithArgmax2D(),
|
||||
kMaxPoolingWithArgmax2DOpVersion);
|
||||
AddCustom(kMaxUnpooling2DOpName, RegisterMaxUnpooling2D(),
|
||||
kMaxUnpooling2DOpVersion);
|
||||
AddCustom(kConvolution2DTransposeBiasOpName,
|
||||
RegisterConvolution2DTransposeBias(),
|
||||
kConvolution2DTransposeBiasOpVersion);
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
Loading…
Reference in New Issue
Block a user