diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD
index a3e44c536..84dcda260 100644
--- a/mediapipe/tasks/cc/core/BUILD
+++ b/mediapipe/tasks/cc/core/BUILD
@@ -80,6 +80,7 @@ cc_library(
         "//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite",
         "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
         "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
+        "//mediapipe/tasks/cc/vision/custom_ops:fused_batch_norm",
         "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
         "//mediapipe/util/tflite/operations:max_pool_argmax",
         "//mediapipe/util/tflite/operations:max_unpooling",
diff --git a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc
index b816d8859..04bc75057 100644
--- a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc
+++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
 #include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
 #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
+#include "mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h"
 #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
 #include "mediapipe/util/tflite/operations/max_pool_argmax.h"
 #include "mediapipe/util/tflite/operations/max_unpooling.h"
@@ -56,6 +57,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
       mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
   AddCustom("RaggedTensorToTensor",
             mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR());
+  AddCustom("FusedBatchNormV3",
+            mediapipe::tflite_operations::Register_FusedBatchNorm());
 }
 }  // namespace core
 }  // namespace tasks
diff --git a/mediapipe/tasks/cc/vision/custom_ops/BUILD b/mediapipe/tasks/cc/vision/custom_ops/BUILD
new file mode 100644
index 000000000..71eda50d3
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/custom_ops/BUILD
@@ -0,0 +1,35 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+cc_library(
+    name = "fused_batch_norm",
+    srcs = ["fused_batch_norm.cc"],
+    hdrs = ["fused_batch_norm.h"],
+    visibility = [
+        "//visibility:public",
+    ],
+    deps =
+        [
+            "@eigen_archive//:eigen3",
+            "@org_tensorflow//tensorflow/lite:framework",
+            "@org_tensorflow//tensorflow/lite/c:common",
+            "@org_tensorflow//tensorflow/lite/core/c:private_common",
+            "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+            "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+        ],
+)
diff --git a/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.cc b/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.cc
new file mode 100644
index 000000000..b3eccd340
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.cc
@@ -0,0 +1,296 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h" + +#include + +#include "Eigen/Core" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace mediapipe::tflite_operations { +namespace vision::batch_norm { +namespace { + +using tflite::GetTensorData; + +constexpr int kInputIndex = 0; +constexpr int kInputScaleIndex = 1; +constexpr int kInputOffsetIndex = 2; +constexpr int kInputEstimatedMeanIndex = 3; +constexpr int kInputEstimatedVarIndex = 4; + +constexpr int kOutputIndex = 0; +constexpr int kOutputBatchMeanIndex = 1; +constexpr int kOutputBatchVarIndex = 2; +constexpr int kOutputSavedMeanIndex = 3; +constexpr int kOutputSavedVarIndex = 4; + +template +struct TTypes { + // Rank- tensor of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Tensor; + + // Rank-1 tensor (vector) of scalar type T. 
+ typedef Eigen::TensorMap, + Eigen::Aligned> + Vec; + typedef Eigen::TensorMap< + Eigen::Tensor> + ConstVec; +}; + +template +void FusedBarchNorm(TfLiteContext* context, TfLiteTensor* x_input, + TfLiteTensor* scale_input, TfLiteTensor* offset_input, + TfLiteTensor* running_mean_input, + TfLiteTensor* running_variance_input, + TfLiteTensor* y_output, TfLiteTensor* running_mean_output, + TfLiteTensor* running_var_output, + TfLiteTensor* saved_batch_mean_output, + TfLiteTensor* saved_batch_var_output, + U exponential_avg_factor, U epsilon) { + const int batches = x_input->dims->data[0]; + const int height = x_input->dims->data[1]; + const int width = x_input->dims->data[2]; + const int depth = x_input->dims->data[3]; + + Eigen::array x_dims = {batches, height, width, depth}; + Eigen::array depth_dims = {depth}; + + const int rest_size = batches * height * width; + + typename TTypes::Tensor x(GetTensorData(x_input), x_dims); + typename TTypes::ConstVec scale(GetTensorData(scale_input), depth_dims); + typename TTypes::ConstVec offset(GetTensorData(offset_input), + depth_dims); + typename TTypes::ConstVec old_mean(GetTensorData(running_mean_input), + depth_dims); + typename TTypes::ConstVec old_variance( + GetTensorData(running_variance_input), depth_dims); + typename TTypes::Tensor y(GetTensorData(y_output), x_dims); + typename TTypes::Vec new_mean(GetTensorData(running_mean_output), + depth_dims); + typename TTypes::Vec new_variance(GetTensorData(running_var_output), + depth_dims); + typename TTypes::Vec saved_batch_mean( + GetTensorData(saved_batch_mean_output), depth_dims); + typename TTypes::Vec saved_batch_var( + GetTensorData(saved_batch_var_output), depth_dims); + + Eigen::DSizes rest_by_depth(rest_size, depth); + Eigen::DSizes tensor_shape(batches, height, width, depth); + + Eigen::IndexList, Eigen::Index> one_by_depth; + one_by_depth.set(1, depth); + Eigen::IndexList> reduce_dims; + Eigen::IndexList> bcast_spec; + bcast_spec.set(0, rest_size); + + auto 
x_rest_by_depth = x.reshape(rest_by_depth).template cast(); + const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1; + U rest_size_inv = static_cast(1.0f / static_cast(rest_size)); + // This adjustment is for Bessel's correction + U rest_size_adjust = + static_cast(rest_size) / static_cast(rest_size_minus_one); + + Eigen::Tensor batch_mean(depth); + Eigen::Tensor batch_variance(depth); + + batch_mean = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv); + auto x_centered = + x_rest_by_depth - batch_mean.reshape(one_by_depth).broadcast(bcast_spec); + + batch_variance = x_centered.square().sum(reduce_dims) * rest_size_inv; + auto scaling_factor = ((batch_variance + epsilon).rsqrt() * scale) + .eval() + .reshape(one_by_depth) + .broadcast(bcast_spec); + auto x_scaled = x_centered * scaling_factor; + auto x_shifted = + (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec)) + .template cast(); + + y.reshape(rest_by_depth) = x_shifted; + if (exponential_avg_factor == U(1.0)) { + saved_batch_var = batch_variance; + saved_batch_mean = batch_mean; + new_variance = batch_variance * rest_size_adjust; + new_mean = batch_mean; + } else { + U one_minus_factor = U(1) - exponential_avg_factor; + saved_batch_var = batch_variance; + saved_batch_mean = batch_mean; + new_variance = one_minus_factor * old_variance + + (exponential_avg_factor * rest_size_adjust) * batch_variance; + new_mean = + one_minus_factor * old_mean + exponential_avg_factor * batch_mean; + } +} + +} // namespace + +// Initializes FusedBatchNorm object from serialized parameters. 
+void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, + size_t /*length*/) { + return nullptr; +} + +void Free(TfLiteContext* /*context*/, void* /*buffer*/) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 5); + TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 6); + + TfLiteTensor* output = tflite::GetOutput(context, node, kOutputIndex); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteTensor* batch_mean = + tflite::GetOutput(context, node, kOutputBatchMeanIndex); + TF_LITE_ENSURE(context, batch_mean != nullptr); + TfLiteTensor* batch_var = + tflite::GetOutput(context, node, kOutputBatchVarIndex); + TF_LITE_ENSURE(context, batch_var != nullptr); + TfLiteTensor* saved_mean = + tflite::GetOutput(context, node, kOutputSavedMeanIndex); + TF_LITE_ENSURE(context, saved_mean != nullptr); + TfLiteTensor* saved_var = + tflite::GetOutput(context, node, kOutputSavedVarIndex); + TF_LITE_ENSURE(context, saved_var != nullptr); + TfLiteTensor* dummy_reserve_space = tflite::GetOutput(context, node, 5); + TF_LITE_ENSURE(context, dummy_reserve_space != nullptr); + + const TfLiteTensor* input = tflite::GetInput(context, node, kInputIndex); + TF_LITE_ENSURE(context, input != nullptr); + const TfLiteTensor* scale = tflite::GetInput(context, node, kInputScaleIndex); + TF_LITE_ENSURE(context, scale != nullptr); + const TfLiteTensor* offset = + tflite::GetInput(context, node, kInputOffsetIndex); + TF_LITE_ENSURE(context, offset != nullptr); + const TfLiteTensor* estimated_mean = + tflite::GetInput(context, node, kInputEstimatedMeanIndex); + TF_LITE_ENSURE(context, estimated_mean != nullptr); + const TfLiteTensor* estimated_var = + tflite::GetInput(context, node, kInputEstimatedVarIndex); + TF_LITE_ENSURE(context, estimated_var != nullptr); + + TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(scale), 1); + 
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(offset), 1); + TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(estimated_mean), 1); + TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(estimated_var), 1); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, scale->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, offset->type, kTfLiteFloat32); + + int batches = input->dims->data[0]; + int height = input->dims->data[1]; + int width = input->dims->data[2]; + int depth = input->dims->data[3]; + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = batches; + output_size->data[1] = height; + output_size->data[2] = width; + output_size->data[3] = depth; + if (context->ResizeTensor(context, output, output_size) != kTfLiteOk) { + return kTfLiteError; + } + TfLiteIntArray* batch_mean_size = TfLiteIntArrayCreate(1); + batch_mean_size->data[0] = depth; + if (context->ResizeTensor(context, batch_mean, batch_mean_size) != + kTfLiteOk) { + return kTfLiteError; + } + TfLiteIntArray* batch_var_size = TfLiteIntArrayCreate(1); + batch_var_size->data[0] = depth; + if (context->ResizeTensor(context, batch_var, batch_var_size) != kTfLiteOk) { + return kTfLiteError; + } + TfLiteIntArray* saved_mean_size = TfLiteIntArrayCreate(1); + saved_mean_size->data[0] = depth; + if (context->ResizeTensor(context, saved_mean, saved_mean_size) != + kTfLiteOk) { + return kTfLiteError; + } + TfLiteIntArray* saved_var_size = TfLiteIntArrayCreate(1); + saved_var_size->data[0] = depth; + if (context->ResizeTensor(context, saved_var, saved_var_size) != kTfLiteOk) { + return kTfLiteError; + } + TfLiteIntArray* dummy_reserve_size = TfLiteIntArrayCreate(1); + dummy_reserve_size->data[0] = 1; + if (context->ResizeTensor(context, dummy_reserve_space, dummy_reserve_size) != + kTfLiteOk) { + return kTfLiteError; + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, 
TfLiteNode* node) { + const TfLiteTensor* input = tflite::GetInput(context, node, kInputIndex); + TF_LITE_ENSURE(context, input != nullptr); + const TfLiteTensor* scale = tflite::GetInput(context, node, kInputScaleIndex); + TF_LITE_ENSURE(context, scale != nullptr); + const TfLiteTensor* offset = + tflite::GetInput(context, node, kInputOffsetIndex); + TF_LITE_ENSURE(context, offset != nullptr); + const TfLiteTensor* estimated_mean = + tflite::GetInput(context, node, kInputEstimatedMeanIndex); + TF_LITE_ENSURE(context, estimated_mean != nullptr); + const TfLiteTensor* estimated_var = + tflite::GetInput(context, node, kInputEstimatedVarIndex); + TF_LITE_ENSURE(context, estimated_var != nullptr); + + TfLiteTensor* output = tflite::GetOutput(context, node, kOutputIndex); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteTensor* batch_mean = + tflite::GetOutput(context, node, kOutputBatchMeanIndex); + TF_LITE_ENSURE(context, batch_mean != nullptr); + TfLiteTensor* batch_var = + tflite::GetOutput(context, node, kOutputBatchVarIndex); + TF_LITE_ENSURE(context, batch_var != nullptr); + TfLiteTensor* saved_mean = + tflite::GetOutput(context, node, kOutputSavedMeanIndex); + TF_LITE_ENSURE(context, saved_mean != nullptr); + TfLiteTensor* saved_var = + tflite::GetOutput(context, node, kOutputSavedVarIndex); + TF_LITE_ENSURE(context, saved_var != nullptr); + + FusedBarchNorm( + context, const_cast(input), + const_cast(scale), const_cast(offset), + const_cast(estimated_mean), + const_cast(estimated_var), output, batch_mean, batch_var, + saved_mean, saved_var, /*exponential_avg_factor=*/0.001f, + /*epsilon=*/0.001f); + + return kTfLiteOk; +} +} // namespace vision::batch_norm + +TfLiteRegistration* Register_FusedBatchNorm() { + static TfLiteRegistration r = { + vision::batch_norm::Initialize, vision::batch_norm::Free, + vision::batch_norm::Prepare, vision::batch_norm::Eval}; + return &r; +} + +} // namespace mediapipe::tflite_operations diff --git 
a/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h b/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h
new file mode 100644
index 000000000..98e16ff92
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/custom_ops/fused_batch_norm.h
@@ -0,0 +1,28 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_CUSTOM_OPS_FUSED_BATCH_NORM_H_
+#define MEDIAPIPE_TASKS_CC_VISION_CUSTOM_OPS_FUSED_BATCH_NORM_H_
+
+#include "tensorflow/lite/core/c/common.h"
+
+namespace mediapipe::tflite_operations {
+
+// The FusedBatchNorm op resolver is CPU-friendly only.
+TfLiteRegistration* Register_FusedBatchNorm();
+
+}  // namespace mediapipe::tflite_operations
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_CUSTOM_OPS_FUSED_BATCH_NORM_H_