Internal change

PiperOrigin-RevId: 538215311
This commit is contained in:
Fergus Henderson 2023-06-06 09:58:18 -07:00 committed by Copybara-Service
parent 37290f0224
commit 709eb812cc
2 changed files with 65 additions and 25 deletions

View File

@ -67,17 +67,20 @@ cc_library(
], ],
) )
# TODO: Re-evaluate which of these libraries we can avoid making # This target has an implementation dependency on TFLite/TFLite-in-GMSCore,
# cc_library_with_tflite and can be changed back to cc_library. # but it does not have any API dependency on TFLite-in-GMSCore.
cc_library_with_tflite( cc_library_with_tflite(
name = "op_resolver", name = "op_resolver",
srcs = ["op_resolver.cc"], srcs = ["op_resolver.cc"],
hdrs = ["op_resolver.h"], hdrs = ["op_resolver.h"],
tflite_deps = [ tflite_deps = [
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/c:c_api",
"@org_tensorflow//tensorflow/lite/c:c_api_experimental", # For c_api_opaque.h
"@org_tensorflow//tensorflow/lite/c:common", # For builtin_op_data.h
], ],
deps = [ deps = [
"@org_tensorflow//tensorflow/lite:builtin_op_data", "@org_tensorflow//tensorflow/lite:builtin_ops",
], ],
) )

View File

@ -14,47 +14,84 @@
#include "mediapipe/util/tflite/op_resolver.h" #include "mediapipe/util/tflite/op_resolver.h"
#include "tensorflow/lite/builtin_op_data.h" #include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_opaque.h"
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kMaxPoolingWithArgmax2DOpName[] = "MaxPoolingWithArgmax2D";
constexpr int kMaxPoolingWithArgmax2DOpVersion = 1;
constexpr char kMaxUnpooling2DOpName[] = "MaxUnpooling2D";
constexpr int kMaxUnpooling2DOpVersion = 1;
constexpr char kConvolution2DTransposeBiasOpName[] =
"Convolution2DTransposeBias";
constexpr int kConvolution2DTransposeBiasOpVersion = 1;
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D() { TfLiteRegistration* RegisterMaxPoolingWithArgmax2D() {
static TfLiteRegistration reg = { static TfLiteRegistrationExternal* reg_external = []() {
[](TfLiteContext*, const char*, size_t) -> void* { // Intentionally allocated and never destroyed.
auto* r = TfLiteRegistrationExternalCreate(
kTfLiteBuiltinCustom, kMaxPoolingWithArgmax2DOpName,
kMaxPoolingWithArgmax2DOpVersion);
TfLiteRegistrationExternalSetInit(
r, [](TfLiteOpaqueContext*, const char*, size_t) -> void* {
return new TfLitePaddingValues(); return new TfLitePaddingValues();
}, });
[](TfLiteContext*, void* buffer) -> void { TfLiteRegistrationExternalSetFree(
r, [](TfLiteOpaqueContext*, void* buffer) -> void {
delete reinterpret_cast<TfLitePaddingValues*>(buffer); delete reinterpret_cast<TfLitePaddingValues*>(buffer);
}, });
[](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { TfLiteRegistrationExternalSetPrepare(
return kTfLiteOk; r,
}, [](TfLiteOpaqueContext* context,
[](TfLiteContext* context, TfLiteNode*) -> TfLiteStatus { TfLiteOpaqueNode* node) -> TfLiteStatus { return kTfLiteOk; });
context->ReportError( TfLiteRegistrationExternalSetInvoke(
r, [](TfLiteOpaqueContext* context, TfLiteOpaqueNode*) -> TfLiteStatus {
TfLiteOpaqueContextReportError(
context, "MaxPoolingWithArgmax2D is only available on the GPU."); context, "MaxPoolingWithArgmax2D is only available on the GPU.");
return kTfLiteError; return kTfLiteError;
}, });
}; return r;
}();
static TfLiteRegistration reg = {.registration_external = reg_external};
return &reg; return &reg;
} }
TfLiteRegistration* RegisterMaxUnpooling2D() { TfLiteRegistration* RegisterMaxUnpooling2D() {
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; static TfLiteRegistrationExternal* reg_external =
// Intentionally allocated and never destroyed.
TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom,
kMaxUnpooling2DOpName,
kMaxUnpooling2DOpVersion);
static TfLiteRegistration reg = {.registration_external = reg_external};
return &reg; return &reg;
} }
TfLiteRegistration* RegisterConvolution2DTransposeBias() { TfLiteRegistration* RegisterConvolution2DTransposeBias() {
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; static TfLiteRegistrationExternal* reg_external =
// Intentionally allocated and never destroyed.
TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom,
kConvolution2DTransposeBiasOpName,
kConvolution2DTransposeBiasOpVersion);
static TfLiteRegistration reg = {.registration_external = reg_external};
return &reg; return &reg;
} }
} // namespace } // namespace
OpResolver::OpResolver() { OpResolver::OpResolver() {
AddCustom("MaxPoolingWithArgmax2D", RegisterMaxPoolingWithArgmax2D()); AddCustom(kMaxPoolingWithArgmax2DOpName, RegisterMaxPoolingWithArgmax2D(),
AddCustom("MaxUnpooling2D", RegisterMaxUnpooling2D()); kMaxPoolingWithArgmax2DOpVersion);
AddCustom("Convolution2DTransposeBias", RegisterConvolution2DTransposeBias()); AddCustom(kMaxUnpooling2DOpName, RegisterMaxUnpooling2D(),
kMaxUnpooling2DOpVersion);
AddCustom(kConvolution2DTransposeBiasOpName,
RegisterConvolution2DTransposeBias(),
kConvolution2DTransposeBiasOpVersion);
} }
} // namespace mediapipe } // namespace mediapipe