diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index ac2ced837..c0db4e35b 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -1405,15 +1405,10 @@ mediapipe_proto_library( cc_library( name = "tensors_to_segmentation_calculator", srcs = ["tensors_to_segmentation_calculator.cc"], - copts = select({ - "//mediapipe:apple": [ - "-x objective-c++", - "-fobjc-arc", # enable reference-counting - ], - "//conditions:default": [], - }), deps = [ ":tensors_to_segmentation_calculator_cc_proto", + ":tensors_to_segmentation_converter", + ":tensors_to_segmentation_utils", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", @@ -1421,9 +1416,11 @@ cc_library( "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/util:resource_util", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -1434,6 +1431,7 @@ cc_library( "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_buffer_format", "//mediapipe/gpu:shader_util", ], }) + selects.with_or({ @@ -1453,13 +1451,65 @@ cc_library( }) + select({ "//mediapipe/framework/port:disable_opencv": [], "//conditions:default": [ - "//mediapipe/framework/formats:image_opencv", - "//mediapipe/framework/port:opencv_imgproc", + ":tensors_to_segmentation_converter_opencv", ], }), alwayslink = 1, ) +cc_library( + name = "tensors_to_segmentation_utils", + srcs = ["tensors_to_segmentation_utils.cc"], + hdrs = ["tensors_to_segmentation_utils.h"], + deps = [ + "//mediapipe/framework:port", + "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "tensors_to_segmentation_utils_test", + srcs = ["tensors_to_segmentation_utils_test.cc"], + deps = [ + ":tensors_to_segmentation_utils", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status_matchers", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "tensors_to_segmentation_converter", + hdrs = ["tensors_to_segmentation_converter.h"], + deps = [ + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "tensors_to_segmentation_converter_opencv", + srcs = ["tensors_to_segmentation_converter_opencv.cc"], + hdrs = ["tensors_to_segmentation_converter_opencv.h"], + deps = [ + ":tensors_to_segmentation_calculator_cc_proto", + ":tensors_to_segmentation_converter", + ":tensors_to_segmentation_utils", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_opencv", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + cc_test( name = "tensors_to_segmentation_calculator_test", srcs = ["tensors_to_segmentation_calculator_test.cc"], diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc index 24fd1bd52..6164c7b0a 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -12,32 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include +#include #include -#include "absl/strings/str_format.h" -#include "absl/types/span.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/gpu/gpu_origin.pb.h" -#include "mediapipe/util/resource_util.h" -#include "tensorflow/lite/interpreter.h" #if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" -#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/shader_util.h" #endif // !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_OPENCV -#include "mediapipe/framework/formats/image_opencv.h" -#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h" #endif // !MEDIAPIPE_DISABLE_OPENCV #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 @@ -62,37 +65,9 @@ namespace { constexpr int kWorkgroupSize = 8; // Block size for GPU shader. enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; -// Commonly used to compute the number of blocks to launch in a kernel. -int NumGroups(const int size, const int group_size) { // NOLINT - return (size + group_size - 1) / group_size; -} - -bool CanUseGpu() { -#if !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED - // TODO: Configure GPU usage policy in individual calculators. - constexpr bool kAllowGpuProcessing = true; - return kAllowGpuProcessing; -#else - return false; -#endif // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED -} - constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; constexpr char kMaskTag[] = "MASK"; - -absl::StatusOr> GetHwcFromDims( - const std::vector& dims) { - if (dims.size() == 3) { - return std::make_tuple(dims[0], dims[1], dims[2]); - } else if (dims.size() == 4) { - // BHWC format check B == 1 - RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap"; - return std::make_tuple(dims[1], dims[2], dims[3]); - } else { - RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size(); - } -} } // namespace namespace mediapipe { @@ -156,19 +131,28 @@ class TensorsToSegmentationCalculator : public CalculatorBase { private: absl::Status LoadOptions(CalculatorContext* cc); absl::Status InitGpu(CalculatorContext* cc); - absl::Status ProcessGpu(CalculatorContext* cc); - absl::Status ProcessCpu(CalculatorContext* cc); + absl::Status ProcessGpu(CalculatorContext* cc, + const std::vector& input_tensors, + std::tuple hwc, int output_width, + int output_height); void GlRender(); bool DoesGpuTextureStartAtBottom() { return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT; } - + absl::Status InitConverterIfNecessary() { #if !MEDIAPIPE_DISABLE_OPENCV - template - absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat); + if (!cpu_converter_) { + MP_ASSIGN_OR_RETURN(cpu_converter_, CreateOpenCvConverter(options_)); + } +#else + RET_CHECK_FAIL() << "OpenCV processing disabled."; #endif // !MEDIAPIPE_DISABLE_OPENCV - ::mediapipe::TensorsToSegmentationCalculatorOptions options_; + return absl::OkStatus(); + } + + mediapipe::TensorsToSegmentationCalculatorOptions options_; + std::unique_ptr cpu_converter_; #if !MEDIAPIPE_DISABLE_GPU mediapipe::GlCalculatorHelper gpu_helper_; @@ -261,7 +245,7 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) { MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); int tensor_channels = std::get<2>(hwc); - typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; + using Options = ::mediapipe::TensorsToSegmentationCalculatorOptions; switch (options_.activation()) { case Options::NONE: RET_CHECK_EQ(tensor_channels, 1); @@ -275,6 +259,17 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) { } } + // Get dimensions. + MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + auto [tensor_height, tensor_width, tensor_channels] = hwc; + int output_width = tensor_width, output_height = tensor_height; + if (cc->Inputs().HasTag(kOutputSizeTag)) { + const auto& size = + cc->Inputs().Tag(kOutputSizeTag).Get>(); + output_width = size.first; + output_height = size.second; + } + if (use_gpu) { #if !MEDIAPIPE_DISABLE_GPU if (!gpu_initialized_) { @@ -286,16 +281,25 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) { #endif // !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { - MP_RETURN_IF_ERROR(ProcessGpu(cc)); - return absl::OkStatus(); - })); + MP_RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, cc, &input_tensors, output_width, + output_height, hwc]() -> absl::Status { + MP_RETURN_IF_ERROR( + ProcessGpu(cc, input_tensors, hwc, output_width, output_height)); + return absl::OkStatus(); + })); #else RET_CHECK_FAIL() << "GPU processing disabled."; #endif // !MEDIAPIPE_DISABLE_GPU } else { #if !MEDIAPIPE_DISABLE_OPENCV - MP_RETURN_IF_ERROR(ProcessCpu(cc)); + // Lazily initialize converter. + MP_RETURN_IF_ERROR(InitConverterIfNecessary()); + MP_ASSIGN_OR_RETURN( + std::unique_ptr output_mask, + cpu_converter_->Convert(input_tensors, output_width, output_height)); + cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), + cc->InputTimestamp()); #else RET_CHECK_FAIL() << "OpenCV processing disabled."; #endif // !MEDIAPIPE_DISABLE_OPENCV @@ -329,132 +333,15 @@ absl::Status TensorsToSegmentationCalculator::Close(CalculatorContext* cc) { return absl::OkStatus(); } -absl::Status TensorsToSegmentationCalculator::ProcessCpu( - CalculatorContext* cc) { -#if !MEDIAPIPE_DISABLE_OPENCV - // Get input streams, and dimensions. - const auto& input_tensors = - cc->Inputs().Tag(kTensorsTag).Get>(); - MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); - auto [tensor_height, tensor_width, tensor_channels] = hwc; - int output_width = tensor_width, output_height = tensor_height; - if (cc->Inputs().HasTag(kOutputSizeTag)) { - const auto& size = - cc->Inputs().Tag(kOutputSizeTag).Get>(); - output_width = size.first; - output_height = size.second; - } - - // Create initial working mask. - cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1); - - // Wrap input tensor. - auto raw_input_tensor = &input_tensors[0]; - auto raw_input_view = raw_input_tensor->GetCpuReadView(); - const float* raw_input_data = raw_input_view.buffer(); - cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height), - CV_MAKETYPE(CV_32F, tensor_channels), - const_cast(raw_input_data)); - - // Process mask tensor and apply activation function. - if (tensor_channels == 2) { - MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); - } else if (tensor_channels == 1) { - RET_CHECK(mediapipe::TensorsToSegmentationCalculatorOptions::SOFTMAX != - options_.activation()); // Requires 2 channels. - if (mediapipe::TensorsToSegmentationCalculatorOptions::NONE == - options_.activation()) // Pass-through optimization. - tensor_mat.copyTo(small_mask_mat); - else - MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); - } else { - RET_CHECK_FAIL() << "Unsupported number of tensor channels " - << tensor_channels; - } - - // Send out image as CPU packet. - std::shared_ptr mask_frame = std::make_shared( - ImageFormat::VEC32F1, output_width, output_height); - std::unique_ptr output_mask = absl::make_unique(mask_frame); - auto output_mat = formats::MatView(output_mask.get()); - // Upsample small mask into output. - cv::resize(small_mask_mat, *output_mat, - cv::Size(output_width, output_height)); - cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); -#endif // !MEDIAPIPE_DISABLE_OPENCV - - return absl::OkStatus(); -} - -#if !MEDIAPIPE_DISABLE_OPENCV -template -absl::Status TensorsToSegmentationCalculator::ApplyActivation( - cv::Mat& tensor_mat, cv::Mat* small_mask_mat) { - // Configure activation function. - const int output_layer_index = options_.output_layer_index(); - typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; - const auto activation_fn = [&](const cv::Vec2f& mask_value) { - float new_mask_value = 0; - // TODO consider moving switch out of the loop, - // and also avoid float/Vec2f casting. - switch (options_.activation()) { - case Options::NONE: { - new_mask_value = mask_value[0]; - break; - } - case Options::SIGMOID: { - const float pixel0 = mask_value[0]; - new_mask_value = 1.0 / (std::exp(-pixel0) + 1.0); - break; - } - case Options::SOFTMAX: { - const float pixel0 = mask_value[0]; - const float pixel1 = mask_value[1]; - const float max_pixel = std::max(pixel0, pixel1); - const float min_pixel = std::min(pixel0, pixel1); - const float softmax_denom = - /*exp(max_pixel - max_pixel)=*/1.0f + - std::exp(min_pixel - max_pixel); - new_mask_value = std::exp(mask_value[output_layer_index] - max_pixel) / - softmax_denom; - break; - } - } - return new_mask_value; - }; - - // Process mask tensor. - for (int i = 0; i < tensor_mat.rows; ++i) { - for (int j = 0; j < tensor_mat.cols; ++j) { - const T& input_pix = tensor_mat.at(i, j); - const float mask_value = activation_fn(input_pix); - small_mask_mat->at(i, j) = mask_value; - } - } - - return absl::OkStatus(); -} -#endif // !MEDIAPIPE_DISABLE_OPENCV - // Steps: // 1. receive tensor // 2. process segmentation tensor into small mask // 3. upsample small mask into output mask to be same size as input image absl::Status TensorsToSegmentationCalculator::ProcessGpu( - CalculatorContext* cc) { + CalculatorContext* cc, const std::vector& input_tensors, + std::tuple hwc, int output_width, int output_height) { #if !MEDIAPIPE_DISABLE_GPU - // Get input streams, and dimensions. - const auto& input_tensors = - cc->Inputs().Tag(kTensorsTag).Get>(); - MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); auto [tensor_height, tensor_width, tensor_channels] = hwc; - int output_width = tensor_width, output_height = tensor_height; - if (cc->Inputs().HasTag(kOutputSizeTag)) { - const auto& size = - cc->Inputs().Tag(kOutputSizeTag).Get>(); - output_width = size.first; - output_height = size.second; - } // Create initial working mask texture. #if !(MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31) @@ -632,7 +519,7 @@ void TensorsToSegmentationCalculator::GlRender() { absl::Status TensorsToSegmentationCalculator::LoadOptions( CalculatorContext* cc) { // Get calculator options specified in the graph. - options_ = cc->Options<::mediapipe::TensorsToSegmentationCalculatorOptions>(); + options_ = cc->Options(); return absl::OkStatus(); } @@ -826,7 +713,7 @@ void main() { #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 // Shader defines. - typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; + using Options = ::mediapipe::TensorsToSegmentationCalculatorOptions; const std::string output_layer_index = "\n#define OUTPUT_LAYER_INDEX int(" + std::to_string(options_.output_layer_index()) + ")"; diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_converter.h b/mediapipe/calculators/tensor/tensors_to_segmentation_converter.h new file mode 100644 index 000000000..61d95dfe0 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_converter.h @@ -0,0 +1,43 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe { + +class TensorsToSegmentationConverter { + public: + virtual ~TensorsToSegmentationConverter() = default; + + // Converts tensors to image mask. + // Returns a unique pointer containing the converted image. + // @input_tensors contains the tensors needed to be processed. + // @output_width/height describes output dimensions to reshape the output mask + // into. + virtual absl::StatusOr> Convert( + const std::vector& input_tensors, int output_width, + int output_height) = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_H_ diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.cc new file mode 100644 index 000000000..1ee2e172b --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.cc @@ -0,0 +1,157 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_opencv.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status_macros.h" + +namespace mediapipe { +namespace { + +class OpenCvProcessor : public TensorsToSegmentationConverter { + public: + absl::Status Init(const TensorsToSegmentationCalculatorOptions& options) { + options_ = options; + return absl::OkStatus(); + } + + absl::StatusOr> Convert( + const std::vector& input_tensors, int output_width, + int output_height) override; + + private: + template + absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat); + + TensorsToSegmentationCalculatorOptions options_; +}; + +absl::StatusOr> OpenCvProcessor::Convert( + const std::vector& input_tensors, int output_width, + int output_height) { + MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + auto [tensor_height, tensor_width, tensor_channels] = hwc; + // Create initial working mask. + cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1); + + // Wrap input tensor. + auto raw_input_tensor = &input_tensors[0]; + auto raw_input_view = raw_input_tensor->GetCpuReadView(); + const float* raw_input_data = raw_input_view.buffer(); + cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height), + CV_MAKETYPE(CV_32F, tensor_channels), + const_cast(raw_input_data)); + + // Process mask tensor and apply activation function. + if (tensor_channels == 2) { + MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); + } else if (tensor_channels == 1) { + RET_CHECK(mediapipe::TensorsToSegmentationCalculatorOptions::SOFTMAX != + options_.activation()); // Requires 2 channels. + if (mediapipe::TensorsToSegmentationCalculatorOptions::NONE == + options_.activation()) // Pass-through optimization. + tensor_mat.copyTo(small_mask_mat); + else + MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); + } else { + RET_CHECK_FAIL() << "Unsupported number of tensor channels " + << tensor_channels; + } + + // Send out image as CPU packet. + std::shared_ptr mask_frame = std::make_shared( + ImageFormat::VEC32F1, output_width, output_height); + auto output_mask = std::make_unique(mask_frame); + auto output_mat = formats::MatView(output_mask.get()); + // Upsample small mask into output. + cv::resize(small_mask_mat, *output_mat, + cv::Size(output_width, output_height)); + return output_mask; +} + +template +absl::Status OpenCvProcessor::ApplyActivation(cv::Mat& tensor_mat, + cv::Mat* small_mask_mat) { + // Configure activation function. + const int output_layer_index = options_.output_layer_index(); + using Options = ::mediapipe::TensorsToSegmentationCalculatorOptions; + const auto activation_fn = [&](const cv::Vec2f& mask_value) { + float new_mask_value = 0; + // TODO consider moving switch out of the loop, + // and also avoid float/Vec2f casting. + switch (options_.activation()) { + case Options::NONE: { + new_mask_value = mask_value[0]; + break; + } + case Options::SIGMOID: { + const float pixel0 = mask_value[0]; + new_mask_value = 1.0 / (std::exp(-pixel0) + 1.0); + break; + } + case Options::SOFTMAX: { + const float pixel0 = mask_value[0]; + const float pixel1 = mask_value[1]; + const float max_pixel = std::max(pixel0, pixel1); + const float min_pixel = std::min(pixel0, pixel1); + const float softmax_denom = + /*exp(max_pixel - max_pixel)=*/1.0f + + std::exp(min_pixel - max_pixel); + new_mask_value = std::exp(mask_value[output_layer_index] - max_pixel) / + softmax_denom; + break; + } + } + return new_mask_value; + }; + + // Process mask tensor. + for (int i = 0; i < tensor_mat.rows; ++i) { + for (int j = 0; j < tensor_mat.cols; ++j) { + const T& input_pix = tensor_mat.at(i, j); + const float mask_value = activation_fn(input_pix); + small_mask_mat->at(i, j) = mask_value; + } + } + + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr> +CreateOpenCvConverter(const TensorsToSegmentationCalculatorOptions& options) { + auto converter = std::make_unique(); + MP_RETURN_IF_ERROR(converter->Init(options)); + return converter; +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h b/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h new file mode 100644 index 000000000..3ae41b5e0 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h @@ -0,0 +1,31 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_OPENCV_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_OPENCV_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter.h" + +namespace mediapipe { +// Creates OpenCV tensors-to-segmentation converter. +absl::StatusOr> +CreateOpenCvConverter( + const mediapipe::TensorsToSegmentationCalculatorOptions& options); +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_OPENCV_H_ diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_utils.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_utils.cc new file mode 100644 index 000000000..ab1e9c139 --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_utils.cc @@ -0,0 +1,52 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/ret_check.h" + +namespace mediapipe { + +int NumGroups(int size, int group_size) { + return (size + group_size - 1) / group_size; +} + +bool CanUseGpu() { +#if !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED + // TODO: Configure GPU usage policy in individual calculators. + constexpr bool kAllowGpuProcessing = true; + return kAllowGpuProcessing; +#else + return false; +#endif // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED +} + +absl::StatusOr> GetHwcFromDims( + const std::vector& dims) { + if (dims.size() == 3) { + return std::make_tuple(dims[0], dims[1], dims[2]); + } else if (dims.size() == 4) { + // BHWC format check B == 1 + RET_CHECK_EQ(dims[0], 1) << "Expected batch to be 1 for BHWC heatmap"; + return std::make_tuple(dims[1], dims[2], dims[3]); + } else { + RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size(); + } +} +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_utils.h b/mediapipe/calculators/tensor/tensors_to_segmentation_utils.h new file mode 100644 index 000000000..44893073b --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_utils.h @@ -0,0 +1,34 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_UTILS_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_UTILS_H_ + +#include +#include + +#include "absl/status/statusor.h" + +namespace mediapipe { + +// Commonly used to compute the number of blocks to launch in a kernel. +int NumGroups(const int size, const int group_size); // NOLINT + +bool CanUseGpu(); + +absl::StatusOr> GetHwcFromDims( + const std::vector& dims); +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_UTILS_H_ diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_utils_test.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_utils_test.cc new file mode 100644 index 000000000..5535d159d --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_utils_test.cc @@ -0,0 +1,63 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +using ::testing::HasSubstr; + +TEST(TensorsToSegmentationUtilsTest, NumGroupsWorksProperly) { + EXPECT_EQ(NumGroups(13, 4), 4); + EXPECT_EQ(NumGroups(4, 13), 1); +} + +TEST(TensorsToSegmentationUtilsTest, GetHwcFromDimsWorksProperly) { + std::vector dims_3 = {2, 3, 4}; + absl::StatusOr> result_1 = GetHwcFromDims(dims_3); + MP_ASSERT_OK(result_1); + EXPECT_EQ(result_1.value(), (std::make_tuple(2, 3, 4))); + std::vector dims_4 = {1, 3, 4, 5}; + absl::StatusOr> result_2 = GetHwcFromDims(dims_4); + MP_ASSERT_OK(result_2); + EXPECT_EQ(result_2.value(), (std::make_tuple(3, 4, 5))); +} + +TEST(TensorsToSegmentationUtilsTest, GetHwcFromDimsBatchCheckFail) { + std::vector dims_4 = {2, 3, 4, 5}; + absl::StatusOr> result = GetHwcFromDims(dims_4); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().message(), + HasSubstr("Expected batch to be 1 for BHWC heatmap")); +} + +TEST(TensorsToSegmentationUtilsTest, GetHwcFromDimsInvalidShape) { + std::vector dims_5 = {1, 2, 3, 4, 5}; + absl::StatusOr> result = GetHwcFromDims(dims_5); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().message(), + HasSubstr("Invalid shape for segmentation tensor")); +} + +} // namespace +} // namespace mediapipe