Internal change
PiperOrigin-RevId: 538215311
This commit is contained in:
parent
37290f0224
commit
709eb812cc
|
@ -67,17 +67,20 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Re-evaluate which of these libraries we can avoid making
|
# This target has an implementation dependency on TFLite/TFLite-in-GMSCore,
|
||||||
# cc_library_with_tflite and can be changed back to cc_library.
|
# but it does not have any API dependency on TFLite-in-GMSCore.
|
||||||
cc_library_with_tflite(
|
cc_library_with_tflite(
|
||||||
name = "op_resolver",
|
name = "op_resolver",
|
||||||
srcs = ["op_resolver.cc"],
|
srcs = ["op_resolver.cc"],
|
||||||
hdrs = ["op_resolver.h"],
|
hdrs = ["op_resolver.h"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"@org_tensorflow//tensorflow/lite/c:c_api",
|
||||||
|
"@org_tensorflow//tensorflow/lite/c:c_api_experimental", # For c_api_opaque.h
|
||||||
|
"@org_tensorflow//tensorflow/lite/c:common", # For builtin_op_data.h
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"@org_tensorflow//tensorflow/lite:builtin_op_data",
|
"@org_tensorflow//tensorflow/lite:builtin_ops",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -14,47 +14,84 @@
|
||||||
|
|
||||||
#include "mediapipe/util/tflite/op_resolver.h"
|
#include "mediapipe/util/tflite/op_resolver.h"
|
||||||
|
|
||||||
#include "tensorflow/lite/builtin_op_data.h"
|
#include "tensorflow/lite/builtin_ops.h"
|
||||||
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
|
#include "tensorflow/lite/c/c_api.h"
|
||||||
|
#include "tensorflow/lite/c/c_api_opaque.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kMaxPoolingWithArgmax2DOpName[] = "MaxPoolingWithArgmax2D";
|
||||||
|
constexpr int kMaxPoolingWithArgmax2DOpVersion = 1;
|
||||||
|
|
||||||
|
constexpr char kMaxUnpooling2DOpName[] = "MaxUnpooling2D";
|
||||||
|
constexpr int kMaxUnpooling2DOpVersion = 1;
|
||||||
|
|
||||||
|
constexpr char kConvolution2DTransposeBiasOpName[] =
|
||||||
|
"Convolution2DTransposeBias";
|
||||||
|
constexpr int kConvolution2DTransposeBiasOpVersion = 1;
|
||||||
|
|
||||||
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D() {
|
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D() {
|
||||||
static TfLiteRegistration reg = {
|
static TfLiteRegistrationExternal* reg_external = []() {
|
||||||
[](TfLiteContext*, const char*, size_t) -> void* {
|
// Intentionally allocated and never destroyed.
|
||||||
return new TfLitePaddingValues();
|
auto* r = TfLiteRegistrationExternalCreate(
|
||||||
},
|
kTfLiteBuiltinCustom, kMaxPoolingWithArgmax2DOpName,
|
||||||
[](TfLiteContext*, void* buffer) -> void {
|
kMaxPoolingWithArgmax2DOpVersion);
|
||||||
delete reinterpret_cast<TfLitePaddingValues*>(buffer);
|
TfLiteRegistrationExternalSetInit(
|
||||||
},
|
r, [](TfLiteOpaqueContext*, const char*, size_t) -> void* {
|
||||||
[](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
|
return new TfLitePaddingValues();
|
||||||
return kTfLiteOk;
|
});
|
||||||
},
|
TfLiteRegistrationExternalSetFree(
|
||||||
[](TfLiteContext* context, TfLiteNode*) -> TfLiteStatus {
|
r, [](TfLiteOpaqueContext*, void* buffer) -> void {
|
||||||
context->ReportError(
|
delete reinterpret_cast<TfLitePaddingValues*>(buffer);
|
||||||
context, "MaxPoolingWithArgmax2D is only available on the GPU.");
|
});
|
||||||
return kTfLiteError;
|
TfLiteRegistrationExternalSetPrepare(
|
||||||
},
|
r,
|
||||||
};
|
[](TfLiteOpaqueContext* context,
|
||||||
|
TfLiteOpaqueNode* node) -> TfLiteStatus { return kTfLiteOk; });
|
||||||
|
TfLiteRegistrationExternalSetInvoke(
|
||||||
|
r, [](TfLiteOpaqueContext* context, TfLiteOpaqueNode*) -> TfLiteStatus {
|
||||||
|
TfLiteOpaqueContextReportError(
|
||||||
|
context, "MaxPoolingWithArgmax2D is only available on the GPU.");
|
||||||
|
return kTfLiteError;
|
||||||
|
});
|
||||||
|
return r;
|
||||||
|
}();
|
||||||
|
static TfLiteRegistration reg = {.registration_external = reg_external};
|
||||||
return ®
|
return ®
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteRegistration* RegisterMaxUnpooling2D() {
|
TfLiteRegistration* RegisterMaxUnpooling2D() {
|
||||||
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
|
static TfLiteRegistrationExternal* reg_external =
|
||||||
|
// Intentionally allocated and never destroyed.
|
||||||
|
TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom,
|
||||||
|
kMaxUnpooling2DOpName,
|
||||||
|
kMaxUnpooling2DOpVersion);
|
||||||
|
static TfLiteRegistration reg = {.registration_external = reg_external};
|
||||||
return ®
|
return ®
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteRegistration* RegisterConvolution2DTransposeBias() {
|
TfLiteRegistration* RegisterConvolution2DTransposeBias() {
|
||||||
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
|
static TfLiteRegistrationExternal* reg_external =
|
||||||
|
// Intentionally allocated and never destroyed.
|
||||||
|
TfLiteRegistrationExternalCreate(kTfLiteBuiltinCustom,
|
||||||
|
kConvolution2DTransposeBiasOpName,
|
||||||
|
kConvolution2DTransposeBiasOpVersion);
|
||||||
|
static TfLiteRegistration reg = {.registration_external = reg_external};
|
||||||
return ®
|
return ®
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
OpResolver::OpResolver() {
|
OpResolver::OpResolver() {
|
||||||
AddCustom("MaxPoolingWithArgmax2D", RegisterMaxPoolingWithArgmax2D());
|
AddCustom(kMaxPoolingWithArgmax2DOpName, RegisterMaxPoolingWithArgmax2D(),
|
||||||
AddCustom("MaxUnpooling2D", RegisterMaxUnpooling2D());
|
kMaxPoolingWithArgmax2DOpVersion);
|
||||||
AddCustom("Convolution2DTransposeBias", RegisterConvolution2DTransposeBias());
|
AddCustom(kMaxUnpooling2DOpName, RegisterMaxUnpooling2D(),
|
||||||
|
kMaxUnpooling2DOpVersion);
|
||||||
|
AddCustom(kConvolution2DTransposeBiasOpName,
|
||||||
|
RegisterConvolution2DTransposeBias(),
|
||||||
|
kConvolution2DTransposeBiasOpVersion);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
Loading…
Reference in New Issue
Block a user