Refactor OpenCV path out of TensorsToSegmentationCalculator main file.

ProcessCpu() is changed into an OpenCV converter that is owned by the calculator. The calculator should call converter.Convert() to get the conversion result.

PiperOrigin-RevId: 581103226
This commit is contained in:
MediaPipe Team 2023-11-09 20:04:36 -08:00 committed by Copybara-Service
parent 1038f8176d
commit fd4859c178
8 changed files with 179 additions and 496 deletions

View File

@ -1405,10 +1405,15 @@ mediapipe_proto_library(
cc_library(
name = "tensors_to_segmentation_calculator",
srcs = ["tensors_to_segmentation_calculator.cc"],
copts = select({
"//mediapipe:apple": [
"-x objective-c++",
"-fobjc-arc", # enable reference-counting
],
"//conditions:default": [],
}),
deps = [
":tensors_to_segmentation_calculator_cc_proto",
":tensors_to_segmentation_converter",
":tensors_to_segmentation_utils",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:port",
@ -1416,11 +1421,9 @@ cc_library(
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/util:resource_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
@ -1431,7 +1434,6 @@ cc_library(
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_buffer_format",
"//mediapipe/gpu:shader_util",
],
}) + selects.with_or({
@ -1451,65 +1453,13 @@ cc_library(
}) + select({
"//mediapipe/framework/port:disable_opencv": [],
"//conditions:default": [
":tensors_to_segmentation_converter_opencv",
"//mediapipe/framework/formats:image_opencv",
"//mediapipe/framework/port:opencv_imgproc",
],
}),
alwayslink = 1,
)
cc_library(
name = "tensors_to_segmentation_utils",
srcs = ["tensors_to_segmentation_utils.cc"],
hdrs = ["tensors_to_segmentation_utils.h"],
deps = [
"//mediapipe/framework:port",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/status:statusor",
],
)
cc_test(
name = "tensors_to_segmentation_utils_test",
srcs = ["tensors_to_segmentation_utils_test.cc"],
deps = [
":tensors_to_segmentation_utils",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status_matchers",
"@com_google_absl//absl/status:statusor",
],
)
cc_library(
name = "tensors_to_segmentation_converter",
hdrs = ["tensors_to_segmentation_converter.h"],
deps = [
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:tensor",
"@com_google_absl//absl/status:statusor",
],
)
cc_library(
name = "tensors_to_segmentation_converter_opencv",
srcs = ["tensors_to_segmentation_converter_opencv.cc"],
hdrs = ["tensors_to_segmentation_converter_opencv.h"],
deps = [
":tensors_to_segmentation_calculator_cc_proto",
":tensors_to_segmentation_converter",
":tensors_to_segmentation_utils",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_opencv",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)
cc_test(
name = "tensors_to_segmentation_calculator_test",
srcs = ["tensors_to_segmentation_calculator_test.cc"],

View File

@ -12,35 +12,32 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter.h"
#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/interpreter.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/gpu_buffer_format.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/shader_util.h"
#endif // !MEDIAPIPE_DISABLE_GPU
#if !MEDIAPIPE_DISABLE_OPENCV
#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h"
#include "mediapipe/framework/formats/image_opencv.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#endif // !MEDIAPIPE_DISABLE_OPENCV
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
@ -65,9 +62,37 @@ namespace {
constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
// Commonly used to compute the number of blocks to launch in a kernel.
int NumGroups(const int size, const int group_size) { // NOLINT
return (size + group_size - 1) / group_size;
}
bool CanUseGpu() {
#if !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED
// TODO: Configure GPU usage policy in individual calculators.
constexpr bool kAllowGpuProcessing = true;
return kAllowGpuProcessing;
#else
return false;
#endif // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED
}
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
constexpr char kMaskTag[] = "MASK";
absl::StatusOr<std::tuple<int, int, int>> GetHwcFromDims(
const std::vector<int>& dims) {
if (dims.size() == 3) {
return std::make_tuple(dims[0], dims[1], dims[2]);
} else if (dims.size() == 4) {
// BHWC format check B == 1
RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap";
return std::make_tuple(dims[1], dims[2], dims[3]);
} else {
RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size();
}
}
} // namespace
namespace mediapipe {
@ -131,28 +156,19 @@ class TensorsToSegmentationCalculator : public CalculatorBase {
private:
absl::Status LoadOptions(CalculatorContext* cc);
absl::Status InitGpu(CalculatorContext* cc);
absl::Status ProcessGpu(CalculatorContext* cc,
const std::vector<Tensor>& input_tensors,
std::tuple<int, int, int> hwc, int output_width,
int output_height);
absl::Status ProcessGpu(CalculatorContext* cc);
absl::Status ProcessCpu(CalculatorContext* cc);
void GlRender();
bool DoesGpuTextureStartAtBottom() {
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
}
absl::Status InitConverterIfNecessary() {
#if !MEDIAPIPE_DISABLE_OPENCV
if (!cpu_converter_) {
MP_ASSIGN_OR_RETURN(cpu_converter_, CreateOpenCvConverter(options_));
}
#else
RET_CHECK_FAIL() << "OpenCV processing disabled.";
#endif // !MEDIAPIPE_DISABLE_OPENCV
return absl::OkStatus();
}
mediapipe::TensorsToSegmentationCalculatorOptions options_;
std::unique_ptr<TensorsToSegmentationConverter> cpu_converter_;
#if !MEDIAPIPE_DISABLE_OPENCV
template <class T>
absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat);
#endif // !MEDIAPIPE_DISABLE_OPENCV
::mediapipe::TensorsToSegmentationCalculatorOptions options_;
#if !MEDIAPIPE_DISABLE_GPU
mediapipe::GlCalculatorHelper gpu_helper_;
@ -245,7 +261,7 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) {
MP_ASSIGN_OR_RETURN(auto hwc,
GetHwcFromDims(input_tensors[0].shape().dims));
int tensor_channels = std::get<2>(hwc);
using Options = ::mediapipe::TensorsToSegmentationCalculatorOptions;
typedef mediapipe::TensorsToSegmentationCalculatorOptions Options;
switch (options_.activation()) {
case Options::NONE:
RET_CHECK_EQ(tensor_channels, 1);
@ -259,17 +275,6 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) {
}
}
// Get dimensions.
MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims));
auto [tensor_height, tensor_width, tensor_channels] = hwc;
int output_width = tensor_width, output_height = tensor_height;
if (cc->Inputs().HasTag(kOutputSizeTag)) {
const auto& size =
cc->Inputs().Tag(kOutputSizeTag).Get<std::pair<int, int>>();
output_width = size.first;
output_height = size.second;
}
if (use_gpu) {
#if !MEDIAPIPE_DISABLE_GPU
if (!gpu_initialized_) {
@ -281,25 +286,16 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) {
#endif // !MEDIAPIPE_DISABLE_GPU
#if !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc, &input_tensors, output_width,
output_height, hwc]() -> absl::Status {
MP_RETURN_IF_ERROR(
ProcessGpu(cc, input_tensors, hwc, output_width, output_height));
return absl::OkStatus();
}));
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status {
MP_RETURN_IF_ERROR(ProcessGpu(cc));
return absl::OkStatus();
}));
#else
RET_CHECK_FAIL() << "GPU processing disabled.";
#endif // !MEDIAPIPE_DISABLE_GPU
} else {
#if !MEDIAPIPE_DISABLE_OPENCV
// Lazily initialize converter.
MP_RETURN_IF_ERROR(InitConverterIfNecessary());
MP_ASSIGN_OR_RETURN(
std::unique_ptr<Image> output_mask,
cpu_converter_->Convert(input_tensors, output_width, output_height));
cc->Outputs().Tag(kMaskTag).Add(output_mask.release(),
cc->InputTimestamp());
MP_RETURN_IF_ERROR(ProcessCpu(cc));
#else
RET_CHECK_FAIL() << "OpenCV processing disabled.";
#endif // !MEDIAPIPE_DISABLE_OPENCV
@ -333,15 +329,132 @@ absl::Status TensorsToSegmentationCalculator::Close(CalculatorContext* cc) {
return absl::OkStatus();
}
absl::Status TensorsToSegmentationCalculator::ProcessCpu(
CalculatorContext* cc) {
#if !MEDIAPIPE_DISABLE_OPENCV
// Get input streams, and dimensions.
const auto& input_tensors =
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims));
auto [tensor_height, tensor_width, tensor_channels] = hwc;
int output_width = tensor_width, output_height = tensor_height;
if (cc->Inputs().HasTag(kOutputSizeTag)) {
const auto& size =
cc->Inputs().Tag(kOutputSizeTag).Get<std::pair<int, int>>();
output_width = size.first;
output_height = size.second;
}
// Create initial working mask.
cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1);
// Wrap input tensor.
auto raw_input_tensor = &input_tensors[0];
auto raw_input_view = raw_input_tensor->GetCpuReadView();
const float* raw_input_data = raw_input_view.buffer<float>();
cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height),
CV_MAKETYPE(CV_32F, tensor_channels),
const_cast<float*>(raw_input_data));
// Process mask tensor and apply activation function.
if (tensor_channels == 2) {
MP_RETURN_IF_ERROR(ApplyActivation<cv::Vec2f>(tensor_mat, &small_mask_mat));
} else if (tensor_channels == 1) {
RET_CHECK(mediapipe::TensorsToSegmentationCalculatorOptions::SOFTMAX !=
options_.activation()); // Requires 2 channels.
if (mediapipe::TensorsToSegmentationCalculatorOptions::NONE ==
options_.activation()) // Pass-through optimization.
tensor_mat.copyTo(small_mask_mat);
else
MP_RETURN_IF_ERROR(ApplyActivation<float>(tensor_mat, &small_mask_mat));
} else {
RET_CHECK_FAIL() << "Unsupported number of tensor channels "
<< tensor_channels;
}
// Send out image as CPU packet.
std::shared_ptr<ImageFrame> mask_frame = std::make_shared<ImageFrame>(
ImageFormat::VEC32F1, output_width, output_height);
std::unique_ptr<Image> output_mask = absl::make_unique<Image>(mask_frame);
auto output_mat = formats::MatView(output_mask.get());
// Upsample small mask into output.
cv::resize(small_mask_mat, *output_mat,
cv::Size(output_width, output_height));
cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp());
#endif // !MEDIAPIPE_DISABLE_OPENCV
return absl::OkStatus();
}
#if !MEDIAPIPE_DISABLE_OPENCV
template <class T>
absl::Status TensorsToSegmentationCalculator::ApplyActivation(
cv::Mat& tensor_mat, cv::Mat* small_mask_mat) {
// Configure activation function.
const int output_layer_index = options_.output_layer_index();
typedef mediapipe::TensorsToSegmentationCalculatorOptions Options;
const auto activation_fn = [&](const cv::Vec2f& mask_value) {
float new_mask_value = 0;
// TODO consider moving switch out of the loop,
// and also avoid float/Vec2f casting.
switch (options_.activation()) {
case Options::NONE: {
new_mask_value = mask_value[0];
break;
}
case Options::SIGMOID: {
const float pixel0 = mask_value[0];
new_mask_value = 1.0 / (std::exp(-pixel0) + 1.0);
break;
}
case Options::SOFTMAX: {
const float pixel0 = mask_value[0];
const float pixel1 = mask_value[1];
const float max_pixel = std::max(pixel0, pixel1);
const float min_pixel = std::min(pixel0, pixel1);
const float softmax_denom =
/*exp(max_pixel - max_pixel)=*/1.0f +
std::exp(min_pixel - max_pixel);
new_mask_value = std::exp(mask_value[output_layer_index] - max_pixel) /
softmax_denom;
break;
}
}
return new_mask_value;
};
// Process mask tensor.
for (int i = 0; i < tensor_mat.rows; ++i) {
for (int j = 0; j < tensor_mat.cols; ++j) {
const T& input_pix = tensor_mat.at<T>(i, j);
const float mask_value = activation_fn(input_pix);
small_mask_mat->at<float>(i, j) = mask_value;
}
}
return absl::OkStatus();
}
#endif // !MEDIAPIPE_DISABLE_OPENCV
// Steps:
// 1. receive tensor
// 2. process segmentation tensor into small mask
// 3. upsample small mask into output mask to be same size as input image
absl::Status TensorsToSegmentationCalculator::ProcessGpu(
CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
std::tuple<int, int, int> hwc, int output_width, int output_height) {
CalculatorContext* cc) {
#if !MEDIAPIPE_DISABLE_GPU
// Get input streams, and dimensions.
const auto& input_tensors =
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims));
auto [tensor_height, tensor_width, tensor_channels] = hwc;
int output_width = tensor_width, output_height = tensor_height;
if (cc->Inputs().HasTag(kOutputSizeTag)) {
const auto& size =
cc->Inputs().Tag(kOutputSizeTag).Get<std::pair<int, int>>();
output_width = size.first;
output_height = size.second;
}
// Create initial working mask texture.
#if !(MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31)
@ -519,7 +632,7 @@ void TensorsToSegmentationCalculator::GlRender() {
absl::Status TensorsToSegmentationCalculator::LoadOptions(
CalculatorContext* cc) {
// Get calculator options specified in the graph.
options_ = cc->Options<mediapipe::TensorsToSegmentationCalculatorOptions>();
options_ = cc->Options<::mediapipe::TensorsToSegmentationCalculatorOptions>();
return absl::OkStatus();
}
@ -713,7 +826,7 @@ void main() {
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
// Shader defines.
using Options = ::mediapipe::TensorsToSegmentationCalculatorOptions;
typedef mediapipe::TensorsToSegmentationCalculatorOptions Options;
const std::string output_layer_index =
"\n#define OUTPUT_LAYER_INDEX int(" +
std::to_string(options_.output_layer_index()) + ")";

View File

@ -1,43 +0,0 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_H_
#define MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_H_
#include <memory>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe {
class TensorsToSegmentationConverter {
public:
virtual ~TensorsToSegmentationConverter() = default;
// Converts tensors to image mask.
// Returns a unique pointer containing the converted image.
// @input_tensors contains the tensors needed to be processed.
// @output_width/height describes output dimensions to reshape the output mask
// into.
virtual absl::StatusOr<std::unique_ptr<Image>> Convert(
const std::vector<Tensor>& input_tensors, int output_width,
int output_height) = 0;
};
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_H_

View File

@ -1,157 +0,0 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter_opencv.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter.h"
#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_opencv.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
namespace mediapipe {
namespace {
class OpenCvProcessor : public TensorsToSegmentationConverter {
public:
absl::Status Init(const TensorsToSegmentationCalculatorOptions& options) {
options_ = options;
return absl::OkStatus();
}
absl::StatusOr<std::unique_ptr<Image>> Convert(
const std::vector<Tensor>& input_tensors, int output_width,
int output_height) override;
private:
template <class T>
absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat);
TensorsToSegmentationCalculatorOptions options_;
};
absl::StatusOr<std::unique_ptr<Image>> OpenCvProcessor::Convert(
const std::vector<Tensor>& input_tensors, int output_width,
int output_height) {
MP_ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims));
auto [tensor_height, tensor_width, tensor_channels] = hwc;
// Create initial working mask.
cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1);
// Wrap input tensor.
auto raw_input_tensor = &input_tensors[0];
auto raw_input_view = raw_input_tensor->GetCpuReadView();
const float* raw_input_data = raw_input_view.buffer<float>();
cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height),
CV_MAKETYPE(CV_32F, tensor_channels),
const_cast<float*>(raw_input_data));
// Process mask tensor and apply activation function.
if (tensor_channels == 2) {
MP_RETURN_IF_ERROR(ApplyActivation<cv::Vec2f>(tensor_mat, &small_mask_mat));
} else if (tensor_channels == 1) {
RET_CHECK(mediapipe::TensorsToSegmentationCalculatorOptions::SOFTMAX !=
options_.activation()); // Requires 2 channels.
if (mediapipe::TensorsToSegmentationCalculatorOptions::NONE ==
options_.activation()) // Pass-through optimization.
tensor_mat.copyTo(small_mask_mat);
else
MP_RETURN_IF_ERROR(ApplyActivation<float>(tensor_mat, &small_mask_mat));
} else {
RET_CHECK_FAIL() << "Unsupported number of tensor channels "
<< tensor_channels;
}
// Send out image as CPU packet.
std::shared_ptr<ImageFrame> mask_frame = std::make_shared<ImageFrame>(
ImageFormat::VEC32F1, output_width, output_height);
auto output_mask = std::make_unique<Image>(mask_frame);
auto output_mat = formats::MatView(output_mask.get());
// Upsample small mask into output.
cv::resize(small_mask_mat, *output_mat,
cv::Size(output_width, output_height));
return output_mask;
}
template <class T>
absl::Status OpenCvProcessor::ApplyActivation(cv::Mat& tensor_mat,
cv::Mat* small_mask_mat) {
// Configure activation function.
const int output_layer_index = options_.output_layer_index();
using Options = ::mediapipe::TensorsToSegmentationCalculatorOptions;
const auto activation_fn = [&](const cv::Vec2f& mask_value) {
float new_mask_value = 0;
// TODO consider moving switch out of the loop,
// and also avoid float/Vec2f casting.
switch (options_.activation()) {
case Options::NONE: {
new_mask_value = mask_value[0];
break;
}
case Options::SIGMOID: {
const float pixel0 = mask_value[0];
new_mask_value = 1.0 / (std::exp(-pixel0) + 1.0);
break;
}
case Options::SOFTMAX: {
const float pixel0 = mask_value[0];
const float pixel1 = mask_value[1];
const float max_pixel = std::max(pixel0, pixel1);
const float min_pixel = std::min(pixel0, pixel1);
const float softmax_denom =
/*exp(max_pixel - max_pixel)=*/1.0f +
std::exp(min_pixel - max_pixel);
new_mask_value = std::exp(mask_value[output_layer_index] - max_pixel) /
softmax_denom;
break;
}
}
return new_mask_value;
};
// Process mask tensor.
for (int i = 0; i < tensor_mat.rows; ++i) {
for (int j = 0; j < tensor_mat.cols; ++j) {
const T& input_pix = tensor_mat.at<T>(i, j);
const float mask_value = activation_fn(input_pix);
small_mask_mat->at<float>(i, j) = mask_value;
}
}
return absl::OkStatus();
}
} // namespace
absl::StatusOr<std::unique_ptr<TensorsToSegmentationConverter>>
CreateOpenCvConverter(const TensorsToSegmentationCalculatorOptions& options) {
auto converter = std::make_unique<OpenCvProcessor>();
MP_RETURN_IF_ERROR(converter->Init(options));
return converter;
}
} // namespace mediapipe

View File

@ -1,31 +0,0 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_OPENCV_H_
#define MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_OPENCV_H_
#include <memory>
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_segmentation_converter.h"
namespace mediapipe {
// Creates OpenCV tensors-to-segmentation converter.
absl::StatusOr<std::unique_ptr<TensorsToSegmentationConverter>>
CreateOpenCvConverter(
const mediapipe::TensorsToSegmentationCalculatorOptions& options);
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_CONVERTER_OPENCV_H_

View File

@ -1,52 +0,0 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h"
#include <tuple>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
int NumGroups(int size, int group_size) {
return (size + group_size - 1) / group_size;
}
bool CanUseGpu() {
#if !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED
// TODO: Configure GPU usage policy in individual calculators.
constexpr bool kAllowGpuProcessing = true;
return kAllowGpuProcessing;
#else
return false;
#endif // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED
}
absl::StatusOr<std::tuple<int, int, int>> GetHwcFromDims(
const std::vector<int>& dims) {
if (dims.size() == 3) {
return std::make_tuple(dims[0], dims[1], dims[2]);
} else if (dims.size() == 4) {
// BHWC format check B == 1
RET_CHECK_EQ(dims[0], 1) << "Expected batch to be 1 for BHWC heatmap";
return std::make_tuple(dims[1], dims[2], dims[3]);
} else {
RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size();
}
}
} // namespace mediapipe

View File

@ -1,34 +0,0 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_UTILS_H_
#define MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_UTILS_H_
#include <tuple>
#include <vector>
#include "absl/status/statusor.h"
namespace mediapipe {
// Commonly used to compute the number of blocks to launch in a kernel.
int NumGroups(const int size, const int group_size); // NOLINT
bool CanUseGpu();
absl::StatusOr<std::tuple<int, int, int>> GetHwcFromDims(
const std::vector<int>& dims);
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_TENSOR_TENSORS_TO_SEGMENTATION_UTILS_H_

View File

@ -1,63 +0,0 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensor/tensors_to_segmentation_utils.h"
#include <tuple>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
using ::testing::HasSubstr;
TEST(TensorsToSegmentationUtilsTest, NumGroupsWorksProperly) {
EXPECT_EQ(NumGroups(13, 4), 4);
EXPECT_EQ(NumGroups(4, 13), 1);
}
TEST(TensorsToSegmentationUtilsTest, GetHwcFromDimsWorksProperly) {
std::vector<int> dims_3 = {2, 3, 4};
absl::StatusOr<std::tuple<int, int, int>> result_1 = GetHwcFromDims(dims_3);
MP_ASSERT_OK(result_1);
EXPECT_EQ(result_1.value(), (std::make_tuple(2, 3, 4)));
std::vector<int> dims_4 = {1, 3, 4, 5};
absl::StatusOr<std::tuple<int, int, int>> result_2 = GetHwcFromDims(dims_4);
MP_ASSERT_OK(result_2);
EXPECT_EQ(result_2.value(), (std::make_tuple(3, 4, 5)));
}
TEST(TensorsToSegmentationUtilsTest, GetHwcFromDimsBatchCheckFail) {
std::vector<int> dims_4 = {2, 3, 4, 5};
absl::StatusOr<std::tuple<int, int, int>> result = GetHwcFromDims(dims_4);
EXPECT_FALSE(result.ok());
EXPECT_THAT(result.status().message(),
HasSubstr("Expected batch to be 1 for BHWC heatmap"));
}
TEST(TensorsToSegmentationUtilsTest, GetHwcFromDimsInvalidShape) {
std::vector<int> dims_5 = {1, 2, 3, 4, 5};
absl::StatusOr<std::tuple<int, int, int>> result = GetHwcFromDims(dims_5);
EXPECT_FALSE(result.ok());
EXPECT_THAT(result.status().message(),
HasSubstr("Invalid shape for segmentation tensor"));
}
} // namespace
} // namespace mediapipe