Internal change

PiperOrigin-RevId: 538215311
This commit is contained in:
Fergus Henderson 2023-06-06 09:58:18 -07:00 committed by Copybara-Service
parent 37290f0224
commit 709eb812cc
2 changed files with 65 additions and 25 deletions

View File

@ -67,17 +67,20 @@ cc_library(
],
)
# TODO: Re-evaluate which of these libraries we can avoid making
# cc_library_with_tflite and can be changed back to cc_library.
# This target has an implementation dependency on TFLite/TFLite-in-GMSCore,
# but it does not have any API dependency on TFLite-in-GMSCore.
cc_library_with_tflite(
name = "op_resolver",
srcs = ["op_resolver.cc"],
hdrs = ["op_resolver.h"],
tflite_deps = [
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/c:c_api",
"@org_tensorflow//tensorflow/lite/c:c_api_experimental", # For c_api_opaque.h
"@org_tensorflow//tensorflow/lite/c:common", # For builtin_op_data.h
],
deps = [
"@org_tensorflow//tensorflow/lite:builtin_op_data",
"@org_tensorflow//tensorflow/lite:builtin_ops",
],
)

View File

@ -14,47 +14,84 @@
#include "mediapipe/util/tflite/op_resolver.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_opaque.h"
namespace mediapipe {
namespace {
constexpr char kMaxPoolingWithArgmax2DOpName[] = "MaxPoolingWithArgmax2D";
constexpr int kMaxPoolingWithArgmax2DOpVersion = 1;
constexpr char kMaxUnpooling2DOpName[] = "MaxUnpooling2D";
constexpr int kMaxUnpooling2DOpVersion = 1;
constexpr char kConvolution2DTransposeBiasOpName[] =
"Convolution2DTransposeBias";
constexpr int kConvolution2DTransposeBiasOpVersion = 1;
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D() {
static TfLiteRegistration reg = {
[](TfLiteContext*, const char*, size_t) -> void* {
return new TfLitePaddingValues();
},
[](TfLiteContext*, void* buffer) -> void {
delete reinterpret_cast<TfLitePaddingValues*>(buffer);
},
[](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
return kTfLiteOk;
},
[](TfLiteContext* context, TfLiteNode*) -> TfLiteStatus {
context->ReportError(
context, "MaxPoolingWithArgmax2D is only available on the GPU.");
return kTfLiteError;
},
};
static TfLiteRegistrationExternal* reg_external = []() {
// Intentionally allocated and never destroyed.
auto* r = TfLiteRegistrationExternalCreate(
kTfLiteBuiltinCustom, kMaxPoolingWithArgmax2DOpName,
kMaxPoolingWithArgmax2DOpVersion);
TfLiteRegistrationExternalSetInit(
r, [](TfLiteOpaqueContext*, const char*, size_t) -> void* {
return new TfLitePaddingValues();
});
TfLiteRegistrationExternalSetFree(
r, [](TfLiteOpaqueContext*, void* buffer) -> void {
delete reinterpret_cast<TfLitePaddingValues*>(buffer);
});
TfLiteRegistrationExternalSetPrepare(
r,
[](TfLiteOpaqueContext* context,
TfLiteOpaqueNode* node) -> TfLiteStatus { return kTfLiteOk; });
TfLiteRegistrationExternalSetInvoke(
r, [](TfLiteOpaqueContext* context, TfLiteOpaqueNode*) -> TfLiteStatus {
TfLiteOpaqueContextReportError(
context, "MaxPoolingWithArgmax2D is only available on the GPU.");
return kTfLiteError;
});
return r;
}();
static TfLiteRegistration reg = {.registration_external = reg_external};
return &reg;
}
TfLiteRegistration* RegisterMaxUnpooling2D() {
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
static TfLiteRegistrationExternal* reg_external =
// Intentionally allocated and never destroyed.
TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom,
kMaxUnpooling2DOpName,
kMaxUnpooling2DOpVersion);
static TfLiteRegistration reg = {.registration_external = reg_external};
return &reg;
}
TfLiteRegistration* RegisterConvolution2DTransposeBias() {
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
static TfLiteRegistrationExternal* reg_external =
// Intentionally allocated and never destroyed.
TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom,
kConvolution2DTransposeBiasOpName,
kConvolution2DTransposeBiasOpVersion);
static TfLiteRegistration reg = {.registration_external = reg_external};
return &reg;
}
} // namespace
OpResolver::OpResolver() {
AddCustom("MaxPoolingWithArgmax2D", RegisterMaxPoolingWithArgmax2D());
AddCustom("MaxUnpooling2D", RegisterMaxUnpooling2D());
AddCustom("Convolution2DTransposeBias", RegisterConvolution2DTransposeBias());
AddCustom(kMaxPoolingWithArgmax2DOpName, RegisterMaxPoolingWithArgmax2D(),
kMaxPoolingWithArgmax2DOpVersion);
AddCustom(kMaxUnpooling2DOpName, RegisterMaxUnpooling2D(),
kMaxUnpooling2DOpVersion);
AddCustom(kConvolution2DTransposeBiasOpName,
RegisterConvolution2DTransposeBias(),
kConvolution2DTransposeBiasOpVersion);
}
} // namespace mediapipe