From 709eb812cc2f593067f7fdc23a481390a70688e3 Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Tue, 6 Jun 2023 09:58:18 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 538215311 --- mediapipe/util/tflite/BUILD | 9 ++-- mediapipe/util/tflite/op_resolver.cc | 81 ++++++++++++++++++++-------- 2 files changed, 65 insertions(+), 25 deletions(-) diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index 59663c9ba..f31c23696 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -67,17 +67,20 @@ cc_library( ], ) -# TODO: Re-evaluate which of these libraries we can avoid making -# cc_library_with_tflite and can be changed back to cc_library. +# This target has an implementation dependency on TFLite/TFLite-in-GMSCore, +# but it does not have any API dependency on TFLite-in-GMSCore. cc_library_with_tflite( name = "op_resolver", srcs = ["op_resolver.cc"], hdrs = ["op_resolver.h"], tflite_deps = [ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/c:c_api", + "@org_tensorflow//tensorflow/lite/c:c_api_experimental", # For c_api_opaque.h + "@org_tensorflow//tensorflow/lite/c:common", # For builtin_op_data.h ], deps = [ - "@org_tensorflow//tensorflow/lite:builtin_op_data", + "@org_tensorflow//tensorflow/lite:builtin_ops", ], ) diff --git a/mediapipe/util/tflite/op_resolver.cc b/mediapipe/util/tflite/op_resolver.cc index 23f066666..44eff4566 100644 --- a/mediapipe/util/tflite/op_resolver.cc +++ b/mediapipe/util/tflite/op_resolver.cc @@ -14,47 +14,84 @@ #include "mediapipe/util/tflite/op_resolver.h" -#include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api.h" +#include "tensorflow/lite/c/c_api_opaque.h" namespace mediapipe { namespace { +constexpr char kMaxPoolingWithArgmax2DOpName[] = "MaxPoolingWithArgmax2D"; +constexpr int kMaxPoolingWithArgmax2DOpVersion = 1; + +constexpr char kMaxUnpooling2DOpName[] = "MaxUnpooling2D"; +constexpr int kMaxUnpooling2DOpVersion = 1; + +constexpr char kConvolution2DTransposeBiasOpName[] = + "Convolution2DTransposeBias"; +constexpr int kConvolution2DTransposeBiasOpVersion = 1; + TfLiteRegistration* RegisterMaxPoolingWithArgmax2D() { - static TfLiteRegistration reg = { - [](TfLiteContext*, const char*, size_t) -> void* { - return new TfLitePaddingValues(); - }, - [](TfLiteContext*, void* buffer) -> void { - delete reinterpret_cast(buffer); - }, - [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { - return kTfLiteOk; - }, - [](TfLiteContext* context, TfLiteNode*) -> TfLiteStatus { - context->ReportError( - context, "MaxPoolingWithArgmax2D is only available on the GPU."); - return kTfLiteError; - }, - }; + static TfLiteRegistrationExternal* reg_external = []() { + // Intentionally allocated and never destroyed. + auto* r = TfLiteRegistrationExternalCreate( + kTfLiteBuiltinCustom, kMaxPoolingWithArgmax2DOpName, + kMaxPoolingWithArgmax2DOpVersion); + TfLiteRegistrationExternalSetInit( + r, [](TfLiteOpaqueContext*, const char*, size_t) -> void* { + return new TfLitePaddingValues(); + }); + TfLiteRegistrationExternalSetFree( + r, [](TfLiteOpaqueContext*, void* buffer) -> void { + delete reinterpret_cast(buffer); + }); + TfLiteRegistrationExternalSetPrepare( + r, + [](TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) -> TfLiteStatus { return kTfLiteOk; }); + TfLiteRegistrationExternalSetInvoke( + r, [](TfLiteOpaqueContext* context, TfLiteOpaqueNode*) -> TfLiteStatus { + TfLiteOpaqueContextReportError( + context, "MaxPoolingWithArgmax2D is only available on the GPU."); + return kTfLiteError; + }); + return r; + }(); + static TfLiteRegistration reg = {.registration_external = reg_external}; return ® } TfLiteRegistration* RegisterMaxUnpooling2D() { - static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + static TfLiteRegistrationExternal* reg_external = + // Intentionally allocated and never destroyed. + TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom, + kMaxUnpooling2DOpName, + kMaxUnpooling2DOpVersion); + static TfLiteRegistration reg = {.registration_external = reg_external}; return ® } TfLiteRegistration* RegisterConvolution2DTransposeBias() { - static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + static TfLiteRegistrationExternal* reg_external = + // Intentionally allocated and never destroyed. + TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom, + kConvolution2DTransposeBiasOpName, + kConvolution2DTransposeBiasOpVersion); + static TfLiteRegistration reg = {.registration_external = reg_external}; return ® } } // namespace OpResolver::OpResolver() { - AddCustom("MaxPoolingWithArgmax2D", RegisterMaxPoolingWithArgmax2D()); - AddCustom("MaxUnpooling2D", RegisterMaxUnpooling2D()); - AddCustom("Convolution2DTransposeBias", RegisterConvolution2DTransposeBias()); + AddCustom(kMaxPoolingWithArgmax2DOpName, RegisterMaxPoolingWithArgmax2D(), + kMaxPoolingWithArgmax2DOpVersion); + AddCustom(kMaxUnpooling2DOpName, RegisterMaxUnpooling2D(), + kMaxUnpooling2DOpVersion); + AddCustom(kConvolution2DTransposeBiasOpName, + RegisterConvolution2DTransposeBias(), + kConvolution2DTransposeBiasOpVersion); } } // namespace mediapipe