Internal change

PiperOrigin-RevId: 518330697
This commit is contained in:
MediaPipe Team 2023-03-21 11:26:28 -07:00 committed by Copybara-Service
parent 2be66e8eb0
commit c2a3e99545
3 changed files with 193 additions and 24 deletions

View File

@ -156,6 +156,7 @@ cc_library(
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector",
"//mediapipe/framework/port:opencv_imgproc",
] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
@ -168,6 +169,25 @@ cc_library(
alwayslink = 1,
)
cc_test(
name = "set_alpha_calculator_test",
srcs = ["set_alpha_calculator_test.cc"],
deps = [
":set_alpha_calculator",
":set_alpha_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "bilateral_filter_calculator",
srcs = ["bilateral_filter_calculator.cc"],

View File

@ -22,6 +22,7 @@
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h"
@ -53,24 +54,16 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
// range of [0, 1). Only the first channel of Alpha is used. Input & output Mat
// must be uchar.
template <typename AlphaType>
absl::Status MergeRGBA8Image(const cv::Mat input_mat, const cv::Mat& alpha_mat,
cv::Mat& output_mat) {
RET_CHECK_EQ(input_mat.rows, alpha_mat.rows);
RET_CHECK_EQ(input_mat.cols, alpha_mat.cols);
RET_CHECK_EQ(input_mat.rows, output_mat.rows);
RET_CHECK_EQ(input_mat.cols, output_mat.cols);
absl::Status CopyAlphaImage(const cv::Mat& alpha_mat, cv::Mat& output_mat) {
RET_CHECK_EQ(output_mat.rows, alpha_mat.rows);
RET_CHECK_EQ(output_mat.cols, alpha_mat.cols);
for (int i = 0; i < output_mat.rows; ++i) {
const uchar* in_ptr = input_mat.ptr<uchar>(i);
const AlphaType* alpha_ptr = alpha_mat.ptr<AlphaType>(i);
uchar* out_ptr = output_mat.ptr<uchar>(i);
for (int j = 0; j < output_mat.cols; ++j) {
const int out_idx = j * kNumChannelsRGBA;
const int in_idx = j * input_mat.channels();
const int alpha_idx = j * alpha_mat.channels();
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
if constexpr (std::is_same<AlphaType, uchar>::value) {
out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask
} else {
@ -273,7 +266,7 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
// Setup source image
const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
const cv::Mat input_mat = mediapipe::formats::MatView(&input_frame);
const cv::Mat input_mat = formats::MatView(&input_frame);
if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) {
LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported";
}
@ -281,38 +274,38 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
// Setup destination image
auto output_frame = absl::make_unique<ImageFrame>(
ImageFormat::SRGBA, input_mat.cols, input_mat.rows);
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
cv::Mat output_mat = formats::MatView(output_frame.get());
const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) &&
!cc->Inputs().Tag(kInputAlphaTag).IsEmpty();
const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask;
// Setup alpha image and Update image in CPU.
// Copy rgb part of the image in CPU
if (input_mat.channels() == 3) {
cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA);
} else {
input_mat.copyTo(output_mat);
}
// Setup alpha image in CPU.
if (use_alpha_mask) {
const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>();
cv::Mat alpha_mat = mediapipe::formats::MatView(&alpha_mask);
cv::Mat alpha_mat = formats::MatView(&alpha_mask);
const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F;
RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U);
if (alpha_is_float) {
MP_RETURN_IF_ERROR(
MergeRGBA8Image<float>(input_mat, alpha_mat, output_mat));
MP_RETURN_IF_ERROR(CopyAlphaImage<float>(alpha_mat, output_mat));
} else {
MP_RETURN_IF_ERROR(
MergeRGBA8Image<uchar>(input_mat, alpha_mat, output_mat));
MP_RETURN_IF_ERROR(CopyAlphaImage<uchar>(alpha_mat, output_mat));
}
} else {
const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f);
for (int i = 0; i < output_mat.rows; ++i) {
const uchar* in_ptr = input_mat.ptr<uchar>(i);
uchar* out_ptr = output_mat.ptr<uchar>(i);
for (int j = 0; j < output_mat.cols; ++j) {
const int out_idx = j * kNumChannelsRGBA;
const int in_idx = j * input_mat.channels();
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
out_ptr[out_idx + 3] = alpha_value; // use value from options
}
}

View File

@ -0,0 +1,156 @@
#include <cstdint>
#include "mediapipe/calculators/image/set_alpha_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "testing/base/public/benchmark.h"
namespace mediapipe {
namespace {
constexpr int input_width = 100;
constexpr int input_height = 100;
std::unique_ptr<ImageFrame> GetInputFrame(int width, int height, int channel) {
const int total_size = width * height * channel;
ImageFormat::Format image_format;
if (channel == 4) {
image_format = ImageFormat::SRGBA;
} else if (channel == 3) {
image_format = ImageFormat::SRGB;
} else {
image_format = ImageFormat::GRAY8;
}
auto input_frame = std::make_unique<ImageFrame>(image_format, width, height,
/*alignment_boundary =*/1);
for (int i = 0; i < total_size; ++i) {
input_frame->MutablePixelData()[i] = i % 256;
}
return input_frame;
}
// Test SetAlphaCalculator with RGB IMAGE input.
TEST(SetAlphaCalculatorTest, CpuRgb) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 3);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
EXPECT_EQ(outputs.NumEntries(), 1);
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
// Generate ground truth (expected_mat).
const auto image = GetInputFrame(input_width, input_height, 3);
const auto input_mat = formats::MatView(image.get());
const auto mask = GetInputFrame(input_width, input_height, 1);
const auto mask_mat = formats::MatView(mask.get());
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 3, 3});
cv::Mat output_mat = formats::MatView(&output_image);
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
EXPECT_FLOAT_EQ(max_diff, 0);
} // TEST
// Test SetAlphaCalculator with RGBA IMAGE input.
TEST(SetAlphaCalculatorTest, CpuRgba) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 4);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
EXPECT_EQ(outputs.NumEntries(), 1);
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
// Generate ground truth (expected_mat).
const auto image = GetInputFrame(input_width, input_height, 4);
const auto input_mat = formats::MatView(image.get());
const auto mask = GetInputFrame(input_width, input_height, 1);
const auto mask_mat = formats::MatView(mask.get());
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 4, 3});
cv::Mat output_mat = formats::MatView(&output_image);
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
EXPECT_FLOAT_EQ(max_diff, 0);
} // TEST
static void BM_SetAlpha3ChannelImage(benchmark::State& state) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 3);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
ASSERT_EQ(1, outputs.NumEntries());
for (const auto _ : state) {
MP_ASSERT_OK(runner.Run());
}
}
BENCHMARK(BM_SetAlpha3ChannelImage);
} // namespace
} // namespace mediapipe