From ae606c155042dfa74f78a941e59811de23c85e07 Mon Sep 17 00:00:00 2001 From: Youchuan Hu Date: Wed, 8 Nov 2023 12:51:09 -0800 Subject: [PATCH] Refactor OpenCV path out of TensorsToSegmentationCalculator main file. ProcessCpu() is changed into an OpenCV converter that is owned by the calculator. The calculator should call converter.Convert() to get the conversion result. PiperOrigin-RevId: 580625461 --- mediapipe/calculators/tensor/BUILD | 82 ++++++- .../tensors_to_segmentation_calculator.cc | 223 +++++------------- .../tensors_to_segmentation_converter.h | 43 ++++ ...ensors_to_segmentation_converter_opencv.cc | 157 ++++++++++++ ...tensors_to_segmentation_converter_opencv.h | 31 +++ .../tensor/tensors_to_segmentation_utils.cc | 52 ++++ .../tensor/tensors_to_segmentation_utils.h | 34 +++ .../tensors_to_segmentation_utils_test.cc | 63 +++++ 8 files changed, 513 insertions(+), 172 deletions(-) create mode 100644 mediapipe/calculators/tensor/tensors_to_segmentation_converter.h create mode 100644 mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.cc create mode 100644 mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h create mode 100644 mediapipe/calculators/tensor/tensors_to_segmentation_utils.cc create mode 100644 mediapipe/calculators/tensor/tensors_to_segmentation_utils.h create mode 100644 mediapipe/calculators/tensor/tensors_to_segmentation_utils_test.cc diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index ac2ced837..76f5bdbf6 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -1414,6 +1414,8 @@ cc_library( }), deps = [ ":tensors_to_segmentation_calculator_cc_proto", + ":tensors_to_segmentation_converter", + ":tensors_to_segmentation_utils", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", @@ -1421,9 +1423,11 @@ cc_library( "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/util:resource_util", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -1434,6 +1438,7 @@ cc_library( "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_buffer_format", "//mediapipe/gpu:shader_util", ], }) + selects.with_or({ @@ -1453,13 +1458,86 @@ cc_library( }) + select({ "//mediapipe/framework/port:disable_opencv": [], "//conditions:default": [ - "//mediapipe/framework/formats:image_opencv", - "//mediapipe/framework/port:opencv_imgproc", + ":tensors_to_segmentation_converter_opencv", ], }), alwayslink = 1, ) +cc_library( + name = "tensors_to_segmentation_utils", + srcs = ["tensors_to_segmentation_utils.cc"], + hdrs = ["tensors_to_segmentation_utils.h"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + deps = [ + "//mediapipe/framework:port", + "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "tensors_to_segmentation_utils_test", + srcs = ["tensors_to_segmentation_utils_test.cc"], + deps = [ + ":tensors_to_segmentation_utils", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status_matchers", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "tensors_to_segmentation_converter", + hdrs = ["tensors_to_segmentation_converter.h"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + deps = [ + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "tensors_to_segmentation_converter_opencv", + srcs = ["tensors_to_segmentation_converter_opencv.cc"], + hdrs = ["tensors_to_segmentation_converter_opencv.h"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + deps = [ + ":tensors_to_segmentation_calculator_cc_proto", + ":tensors_to_segmentation_converter", + ":tensors_to_segmentation_utils", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_opencv", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + cc_test( name = "tensors_to_segmentation_calculator_test", srcs = ["tensors_to_segmentation_calculator_test.cc"], diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc index 24fd1bd52..90d2e6246 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -12,32 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include +#include #include -#include "absl/strings/str_format.h" -#include "absl/types/span.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_origin.pb.h" -#include "mediapipe/util/resource_util.h" -#include "tensorflow/lite/interpreter.h" #if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" -#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" #endif // !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_OPENCV -#include "mediapipe/framework/formats/image_opencv.h" -#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h" #endif // !MEDIAPIPE_DISABLE_OPENCV #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 @@ -62,37 +65,9 @@ namespace { constexpr int kWorkgroupSize = 8; // Block size for GPU shader. enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; -// Commonly used to compute the number of blocks to launch in a kernel. -int NumGroups(const int size, const int group_size) { // NOLINT - return (size + group_size - 1) / group_size; -} - -bool CanUseGpu() { -#if !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED - // TODO: Configure GPU usage policy in individual calculators. - constexpr bool kAllowGpuProcessing = true; - return kAllowGpuProcessing; -#else - return false; -#endif // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED -} - constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; constexpr char kMaskTag[] = "MASK"; - -absl::StatusOr> GetHwcFromDims( - const std::vector& dims) { - if (dims.size() == 3) { - return std::make_tuple(dims[0], dims[1], dims[2]); - } else if (dims.size() == 4) { - // BHWC format check B == 1 - RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap"; - return std::make_tuple(dims[1], dims[2], dims[3]); - } else { - RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size(); - } -} } // namespace namespace mediapipe { @@ -156,19 +131,24 @@ class TensorsToSegmentationCalculator : public CalculatorBase { private: absl::Status LoadOptions(CalculatorContext* cc); absl::Status InitGpu(CalculatorContext* cc); - absl::Status ProcessGpu(CalculatorContext* cc); - absl::Status ProcessCpu(CalculatorContext* cc); + absl::Status ProcessGpu(CalculatorContext* cc, + const std::vector& input_tensors, + std::tuple hwc, int output_width, + int output_height); void GlRender(); bool DoesGpuTextureStartAtBottom() { return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT; } + absl::Status InitConverterIfNecessary() { + if (!cpu_converter_) { + MP_ASSIGN_OR_RETURN(cpu_converter_, CreateOpenCvConverter(options_)); + } + return absl::OkStatus(); + } -#if !MEDIAPIPE_DISABLE_OPENCV - template - absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat); -#endif // !MEDIAPIPE_DISABLE_OPENCV - ::mediapipe::TensorsToSegmentationCalculatorOptions options_; + mediapipe::TensorsToSegmentationCalculatorOptions options_; + std::unique_ptr cpu_converter_; #if !MEDIAPIPE_DISABLE_GPU mediapipe::GlCalculatorHelper gpu_helper_; @@ -261,7 +241,7 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) { MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); int tensor_channels = std::get<2>(hwc); - typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; + using Options = ::mediapipe::TensorsToSegmentationCalculatorOptions; switch (options_.activation()) { case Options::NONE: RET_CHECK_EQ(tensor_channels, 1); @@ -275,6 +255,17 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) { } } + // Get dimensions. + MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + auto [tensor_height, tensor_width, tensor_channels] = hwc; + int output_width = tensor_width, output_height = tensor_height; + if (cc->Inputs().HasTag(kOutputSizeTag)) { + const auto& size = + cc->Inputs().Tag(kOutputSizeTag).Get>(); + output_width = size.first; + output_height = size.second; + } + if (use_gpu) { #if !MEDIAPIPE_DISABLE_GPU if (!gpu_initialized_) { @@ -286,16 +277,25 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) { #endif // !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { - MP_RETURN_IF_ERROR(ProcessGpu(cc)); - return absl::OkStatus(); - })); + MP_RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, cc, &input_tensors, output_width, + output_height, hwc]() -> absl::Status { + MP_RETURN_IF_ERROR( + ProcessGpu(cc, input_tensors, hwc, output_width, output_height)); + return absl::OkStatus(); + })); #else RET_CHECK_FAIL() << "GPU processing disabled."; #endif // !MEDIAPIPE_DISABLE_GPU } else { #if !MEDIAPIPE_DISABLE_OPENCV - MP_RETURN_IF_ERROR(ProcessCpu(cc)); + // Lazily initialize converter. + MP_RETURN_IF_ERROR(InitConverterIfNecessary()); + MP_ASSIGN_OR_RETURN( + std::unique_ptr output_mask, + cpu_converter_->Convert(input_tensors, output_width, output_height)); + cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), + cc->InputTimestamp()); #else RET_CHECK_FAIL() << "OpenCV processing disabled."; #endif // !MEDIAPIPE_DISABLE_OPENCV @@ -329,132 +329,15 @@ absl::Status TensorsToSegmentationCalculator::Close(CalculatorContext* cc) { return absl::OkStatus(); } -absl::Status TensorsToSegmentationCalculator::ProcessCpu( - CalculatorContext* cc) { -#if !MEDIAPIPE_DISABLE_OPENCV - // Get input streams, and dimensions. - const auto& input_tensors = - cc->Inputs().Tag(kTensorsTag).Get>(); - MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); - auto [tensor_height, tensor_width, tensor_channels] = hwc; - int output_width = tensor_width, output_height = tensor_height; - if (cc->Inputs().HasTag(kOutputSizeTag)) { - const auto& size = - cc->Inputs().Tag(kOutputSizeTag).Get>(); - output_width = size.first; - output_height = size.second; - } - - // Create initial working mask. - cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1); - - // Wrap input tensor. - auto raw_input_tensor = &input_tensors[0]; - auto raw_input_view = raw_input_tensor->GetCpuReadView(); - const float* raw_input_data = raw_input_view.buffer(); - cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height), - CV_MAKETYPE(CV_32F, tensor_channels), - const_cast(raw_input_data)); - - // Process mask tensor and apply activation function. - if (tensor_channels == 2) { - MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); - } else if (tensor_channels == 1) { - RET_CHECK(mediapipe::TensorsToSegmentationCalculatorOptions::SOFTMAX != - options_.activation()); // Requires 2 channels. - if (mediapipe::TensorsToSegmentationCalculatorOptions::NONE == - options_.activation()) // Pass-through optimization. - tensor_mat.copyTo(small_mask_mat); - else - MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); - } else { - RET_CHECK_FAIL() << "Unsupported number of tensor channels " - << tensor_channels; - } - - // Send out image as CPU packet. - std::shared_ptr mask_frame = std::make_shared( - ImageFormat::VEC32F1, output_width, output_height); - std::unique_ptr output_mask = absl::make_unique(mask_frame); - auto output_mat = formats::MatView(output_mask.get()); - // Upsample small mask into output. - cv::resize(small_mask_mat, *output_mat, - cv::Size(output_width, output_height)); - cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); -#endif // !MEDIAPIPE_DISABLE_OPENCV - - return absl::OkStatus(); -} - -#if !MEDIAPIPE_DISABLE_OPENCV -template -absl::Status TensorsToSegmentationCalculator::ApplyActivation( - cv::Mat& tensor_mat, cv::Mat* small_mask_mat) { - // Configure activation function. - const int output_layer_index = options_.output_layer_index(); - typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; - const auto activation_fn = [&](const cv::Vec2f& mask_value) { - float new_mask_value = 0; - // TODO consider moving switch out of the loop, - // and also avoid float/Vec2f casting. - switch (options_.activation()) { - case Options::NONE: { - new_mask_value = mask_value[0]; - break; - } - case Options::SIGMOID: { - const float pixel0 = mask_value[0]; - new_mask_value = 1.0 / (std::exp(-pixel0) + 1.0); - break; - } - case Options::SOFTMAX: { - const float pixel0 = mask_value[0]; - const float pixel1 = mask_value[1]; - const float max_pixel = std::max(pixel0, pixel1); - const float min_pixel = std::min(pixel0, pixel1); - const float softmax_denom = - /*exp(max_pixel - max_pixel)=*/1.0f + - std::exp(min_pixel - max_pixel); - new_mask_value = std::exp(mask_value[output_layer_index] - max_pixel) / - softmax_denom; - break; - } - } - return new_mask_value; - }; - - // Process mask tensor. - for (int i = 0; i < tensor_mat.rows; ++i) { - for (int j = 0; j < tensor_mat.cols; ++j) { - const T& input_pix = tensor_mat.at(i, j); - const float mask_value = activation_fn(input_pix); - small_mask_mat->at(i, j) = mask_value; - } - } - - return absl::OkStatus(); -} -#endif // !MEDIAPIPE_DISABLE_OPENCV - // Steps: // 1. receive tensor // 2. process segmentation tensor into small mask // 3. upsample small mask into output mask to be same size as input image absl::Status TensorsToSegmentationCalculator::ProcessGpu( - CalculatorContext* cc) { + CalculatorContext* cc, const std::vector& input_tensors, + std::tuple hwc, int output_width, int output_height) { #if !MEDIAPIPE_DISABLE_GPU - // Get input streams, and dimensions. - const auto& input_tensors = - cc->Inputs().Tag(kTensorsTag).Get>(); - MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); auto [tensor_height, tensor_width, tensor_channels] = hwc; - int output_width = tensor_width, output_height = tensor_height; - if (cc->Inputs().HasTag(kOutputSizeTag)) { - const auto& size = - cc->Inputs().Tag(kOutputSizeTag).Get>(); - output_width = size.first; - output_height = size.second; - } // Create initial working mask texture. #if !(MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31) @@ -632,7 +515,7 @@ void TensorsToSegmentationCalculator::GlRender() { absl::Status TensorsToSegmentationCalculator::LoadOptions( CalculatorContext* cc) { // Get calculator options specified in the graph. - options_ = cc->Options<::mediapipe::TensorsToSegmentationCalculatorOptions>(); + options_ = cc->Options(); return absl::OkStatus(); } @@ -826,7 +709,7 @@ void main() { #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 // Shader defines. - typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; + using Options = ::mediapipe::TensorsToSegmentationCalculatorOptions; const std::string output_layer_index = "\n#define OUTPUT_LAYER_INDEX int(" + std::to_string(options_.output_layer_index()) + ")"; diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_converter.h b/mediapipe/calculators/tensor/tensors_to_segmentation_converter.h new file mode 100644 index 000000000..61d95dfe0 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_converter.h @@ -0,0 +1,43 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe { + +class TensorsToSegmentationConverter { + public: + virtual ~TensorsToSegmentationConverter() = default; + + // Converts tensors to image mask. + // Returns a unique pointer containing the converted image. + // @input_tensors contains the tensors needed to be processed. + // @output_width/height describes output dimensions to reshape the output mask + // into. + virtual absl::StatusOr> Convert( + const std::vector& input_tensors, int output_width, + int output_height) = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_H_ diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.cc new file mode 100644 index 000000000..1ee2e172b --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.cc @@ -0,0 +1,157 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_opencv.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_macros.h" + +namespace mediapipe { +namespace { + +class OpenCvProcessor : public TensorsToSegmentationConverter { + public: + absl::Status Init(const TensorsToSegmentationCalculatorOptions& options) { + options_ = options; + return absl::OkStatus(); + } + + absl::StatusOr> Convert( + const std::vector& input_tensors, int output_width, + int output_height) override; + + private: + template + absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat); + + TensorsToSegmentationCalculatorOptions options_; +}; + +absl::StatusOr> OpenCvProcessor::Convert( + const std::vector& input_tensors, int output_width, + int output_height) { + MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + auto [tensor_height, tensor_width, tensor_channels] = hwc; + // Create initial working mask. + cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1); + + // Wrap input tensor. + auto raw_input_tensor = &input_tensors[0]; + auto raw_input_view = raw_input_tensor->GetCpuReadView(); + const float* raw_input_data = raw_input_view.buffer(); + cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height), + CV_MAKETYPE(CV_32F, tensor_channels), + const_cast(raw_input_data)); + + // Process mask tensor and apply activation function. + if (tensor_channels == 2) { + MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); + } else if (tensor_channels == 1) { + RET_CHECK(mediapipe::TensorsToSegmentationCalculatorOptions::SOFTMAX != + options_.activation()); // Requires 2 channels. + if (mediapipe::TensorsToSegmentationCalculatorOptions::NONE == + options_.activation()) // Pass-through optimization. + tensor_mat.copyTo(small_mask_mat); + else + MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); + } else { + RET_CHECK_FAIL() << "Unsupported number of tensor channels " + << tensor_channels; + } + + // Send out image as CPU packet. + std::shared_ptr mask_frame = std::make_shared( + ImageFormat::VEC32F1, output_width, output_height); + auto output_mask = std::make_unique(mask_frame); + auto output_mat = formats::MatView(output_mask.get()); + // Upsample small mask into output. + cv::resize(small_mask_mat, *output_mat, + cv::Size(output_width, output_height)); + return output_mask; +} + +template +absl::Status OpenCvProcessor::ApplyActivation(cv::Mat& tensor_mat, + cv::Mat* small_mask_mat) { + // Configure activation function. + const int output_layer_index = options_.output_layer_index(); + using Options = ::mediapipe::TensorsToSegmentationCalculatorOptions; + const auto activation_fn = [&](const cv::Vec2f& mask_value) { + float new_mask_value = 0; + // TODO consider moving switch out of the loop, + // and also avoid float/Vec2f casting. + switch (options_.activation()) { + case Options::NONE: { + new_mask_value = mask_value[0]; + break; + } + case Options::SIGMOID: { + const float pixel0 = mask_value[0]; + new_mask_value = 1.0 / (std::exp(-pixel0) + 1.0); + break; + } + case Options::SOFTMAX: { + const float pixel0 = mask_value[0]; + const float pixel1 = mask_value[1]; + const float max_pixel = std::max(pixel0, pixel1); + const float min_pixel = std::min(pixel0, pixel1); + const float softmax_denom = + /*exp(max_pixel - max_pixel)=*/1.0f + + std::exp(min_pixel - max_pixel); + new_mask_value = std::exp(mask_value[output_layer_index] - max_pixel) / + softmax_denom; + break; + } + } + return new_mask_value; + }; + + // Process mask tensor. + for (int i = 0; i < tensor_mat.rows; ++i) { + for (int j = 0; j < tensor_mat.cols; ++j) { + const T& input_pix = tensor_mat.at(i, j); + const float mask_value = activation_fn(input_pix); + small_mask_mat->at(i, j) = mask_value; + } + } + + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr> +CreateOpenCvConverter(const TensorsToSegmentationCalculatorOptions& options) { + auto converter = std::make_unique(); + MP_RETURN_IF_ERROR(converter->Init(options)); + return converter; +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h b/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h new file mode 100644 index 000000000..3ae41b5e0 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h @@ -0,0 +1,31 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_OPENCV_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_OPENCV_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter.h" + +namespace mediapipe { +// Creates OpenCV tensors-to-segmentation converter. +absl::StatusOr> +CreateOpenCvConverter( + const mediapipe::TensorsToSegmentationCalculatorOptions& options); +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_OPENCV_H_ diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_utils.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_utils.cc new file mode 100644 index 000000000..ab1e9c139 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_utils.cc @@ -0,0 +1,52 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +int NumGroups(int size, int group_size) { + return (size + group_size - 1) / group_size; +} + +bool CanUseGpu() { +#if !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED + // TODO: Configure GPU usage policy in individual calculators. + constexpr bool kAllowGpuProcessing = true; + return kAllowGpuProcessing; +#else + return false; +#endif // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED +} + +absl::StatusOr> GetHwcFromDims( + const std::vector& dims) { + if (dims.size() == 3) { + return std::make_tuple(dims[0], dims[1], dims[2]); + } else if (dims.size() == 4) { + // BHWC format check B == 1 + RET_CHECK_EQ(dims[0], 1) << "Expected batch to be 1 for BHWC heatmap"; + return std::make_tuple(dims[1], dims[2], dims[3]); + } else { + RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size(); + } +} +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_utils.h b/mediapipe/calculators/tensor/tensors_to_segmentation_utils.h new file mode 100644 index 000000000..44893073b --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_utils.h @@ -0,0 +1,34 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_UTILS_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_UTILS_H_ + +#include +#include + +#include "absl/status/statusor.h" + +namespace mediapipe { + +// Commonly used to compute the number of blocks to launch in a kernel. +int NumGroups(const int size, const int group_size); // NOLINT + +bool CanUseGpu(); + +absl::StatusOr> GetHwcFromDims( + const std::vector& dims); +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_UTILS_H_ diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_utils_test.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_utils_test.cc new file mode 100644 index 000000000..5535d159d --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_utils_test.cc @@ -0,0 +1,63 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +using ::testing::HasSubstr; + +TEST(TensorsToSegmentationUtilsTest, NumGroupsWorksProperly) { + EXPECT_EQ(NumGroups(13, 4), 4); + EXPECT_EQ(NumGroups(4, 13), 1); +} + +TEST(TensorsToSegmentationUtilsTest, GetHwcFromDimsWorksProperly) { + std::vector dims_3 = {2, 3, 4}; + absl::StatusOr> result_1 = GetHwcFromDims(dims_3); + MP_ASSERT_OK(result_1); + EXPECT_EQ(result_1.value(), (std::make_tuple(2, 3, 4))); + std::vector dims_4 = {1, 3, 4, 5}; + absl::StatusOr> result_2 = GetHwcFromDims(dims_4); + MP_ASSERT_OK(result_2); + EXPECT_EQ(result_2.value(), (std::make_tuple(3, 4, 5))); +} + +TEST(TensorsToSegmentationUtilsTest, GetHwcFromDimsBatchCheckFail) { + std::vector dims_4 = {2, 3, 4, 5}; + absl::StatusOr> result = GetHwcFromDims(dims_4); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().message(), + HasSubstr("Expected batch to be 1 for BHWC heatmap")); +} + +TEST(TensorsToSegmentationUtilsTest, GetHwcFromDimsInvalidShape) { + std::vector dims_5 = {1, 2, 3, 4, 5}; + absl::StatusOr> result = GetHwcFromDims(dims_5); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().message(), + HasSubstr("Invalid shape for segmentation tensor")); +} + +} // namespace +} // namespace mediapipe