Add option for nearest neighbor interpolation.

PiperOrigin-RevId: 564786213
This commit is contained in:
MediaPipe Team 2023-09-12 11:39:16 -07:00 committed by Copybara-Service
parent 5daed78844
commit 4ba1dadf92
6 changed files with 294 additions and 35 deletions

View File

@ -263,7 +263,10 @@ cc_library(
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/gpu:gl_base_hdr",
"//mediapipe/gpu:scale_mode_cc_proto", "//mediapipe/gpu:scale_mode_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [ "//conditions:default": [
@ -276,6 +279,36 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "image_transformation_calculator_test",
srcs = ["image_transformation_calculator_test.cc"],
data = ["//mediapipe/calculators/image/testdata:test_images"],
tags = [
"desktop_only_test",
],
deps = [
":image_transformation_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/gpu:gpu_buffer_to_image_frame_calculator",
"//mediapipe/gpu:image_frame_to_gpu_buffer_calculator",
"//third_party:opencv",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
cc_library( cc_library(
name = "image_cropping_calculator", name = "image_cropping_calculator",
srcs = ["image_cropping_calculator.cc"], srcs = ["image_cropping_calculator.cc"],

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/status/status.h"
#include "mediapipe/calculators/image/image_transformation_calculator.pb.h" #include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
#include "mediapipe/calculators/image/rotation_mode.pb.h" #include "mediapipe/calculators/image/rotation_mode.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -24,6 +25,7 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/scale_mode.pb.h" #include "mediapipe/gpu/scale_mode.pb.h"
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
@ -60,42 +62,42 @@ constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) { int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) {
switch (rotation) { switch (rotation) {
case mediapipe::RotationMode_Mode_UNKNOWN: case mediapipe::RotationMode::UNKNOWN:
case mediapipe::RotationMode_Mode_ROTATION_0: case mediapipe::RotationMode::ROTATION_0:
return 0; return 0;
case mediapipe::RotationMode_Mode_ROTATION_90: case mediapipe::RotationMode::ROTATION_90:
return 90; return 90;
case mediapipe::RotationMode_Mode_ROTATION_180: case mediapipe::RotationMode::ROTATION_180:
return 180; return 180;
case mediapipe::RotationMode_Mode_ROTATION_270: case mediapipe::RotationMode::ROTATION_270:
return 270; return 270;
} }
} }
mediapipe::RotationMode_Mode DegreesToRotationMode(int degrees) { mediapipe::RotationMode_Mode DegreesToRotationMode(int degrees) {
switch (degrees) { switch (degrees) {
case 0: case 0:
return mediapipe::RotationMode_Mode_ROTATION_0; return mediapipe::RotationMode::ROTATION_0;
case 90: case 90:
return mediapipe::RotationMode_Mode_ROTATION_90; return mediapipe::RotationMode::ROTATION_90;
case 180: case 180:
return mediapipe::RotationMode_Mode_ROTATION_180; return mediapipe::RotationMode::ROTATION_180;
case 270: case 270:
return mediapipe::RotationMode_Mode_ROTATION_270; return mediapipe::RotationMode::ROTATION_270;
default: default:
return mediapipe::RotationMode_Mode_UNKNOWN; return mediapipe::RotationMode::UNKNOWN;
} }
} }
mediapipe::ScaleMode_Mode ParseScaleMode( mediapipe::ScaleMode_Mode ParseScaleMode(
mediapipe::ScaleMode_Mode scale_mode, mediapipe::ScaleMode_Mode scale_mode,
mediapipe::ScaleMode_Mode default_mode) { mediapipe::ScaleMode_Mode default_mode) {
switch (scale_mode) { switch (scale_mode) {
case mediapipe::ScaleMode_Mode_DEFAULT: case mediapipe::ScaleMode::DEFAULT:
return default_mode; return default_mode;
case mediapipe::ScaleMode_Mode_STRETCH: case mediapipe::ScaleMode::STRETCH:
return scale_mode; return scale_mode;
case mediapipe::ScaleMode_Mode_FIT: case mediapipe::ScaleMode::FIT:
return scale_mode; return scale_mode;
case mediapipe::ScaleMode_Mode_FILL_AND_CROP: case mediapipe::ScaleMode::FILL_AND_CROP:
return scale_mode; return scale_mode;
default: default:
return default_mode; return default_mode;
@ -208,6 +210,8 @@ class ImageTransformationCalculator : public CalculatorBase {
bool use_gpu_ = false; bool use_gpu_ = false;
cv::Scalar padding_color_; cv::Scalar padding_color_;
ImageTransformationCalculatorOptions::InterpolationMode interpolation_mode_;
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
GlCalculatorHelper gpu_helper_; GlCalculatorHelper gpu_helper_;
std::unique_ptr<QuadRenderer> rgb_renderer_; std::unique_ptr<QuadRenderer> rgb_renderer_;
@ -343,6 +347,11 @@ absl::Status ImageTransformationCalculator::Open(CalculatorContext* cc) {
options_.padding_color().green(), options_.padding_color().green(),
options_.padding_color().blue()); options_.padding_color().blue());
interpolation_mode_ = options_.interpolation_mode();
if (options_.interpolation_mode() ==
ImageTransformationCalculatorOptions::DEFAULT) {
interpolation_mode_ = ImageTransformationCalculatorOptions::LINEAR;
}
if (use_gpu_) { if (use_gpu_) {
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
// Let the helper access the GL context information. // Let the helper access the GL context information.
@ -457,26 +466,48 @@ absl::Status ImageTransformationCalculator::RenderCpu(CalculatorContext* cc) {
ComputeOutputDimensions(input_width, input_height, &output_width, ComputeOutputDimensions(input_width, input_height, &output_width,
&output_height); &output_height);
int opencv_interpolation_mode = cv::INTER_LINEAR;
if (output_width_ > 0 && output_height_ > 0) { if (output_width_ > 0 && output_height_ > 0) {
cv::Mat scaled_mat; cv::Mat scaled_mat;
if (scale_mode_ == mediapipe::ScaleMode_Mode_STRETCH) { if (scale_mode_ == mediapipe::ScaleMode::STRETCH) {
int scale_flag = if (interpolation_mode_ == ImageTransformationCalculatorOptions::LINEAR) {
input_mat.cols > output_width_ && input_mat.rows > output_height_ // Use INTER_AREA for downscaling if interpolation mode is set to
? cv::INTER_AREA // LINEAR.
: cv::INTER_LINEAR; if (input_mat.cols > output_width_ && input_mat.rows > output_height_) {
opencv_interpolation_mode = cv::INTER_AREA;
} else {
opencv_interpolation_mode = cv::INTER_LINEAR;
}
} else {
opencv_interpolation_mode = cv::INTER_NEAREST;
}
cv::resize(input_mat, scaled_mat, cv::Size(output_width_, output_height_), cv::resize(input_mat, scaled_mat, cv::Size(output_width_, output_height_),
0, 0, scale_flag); 0, 0, opencv_interpolation_mode);
} else { } else {
const float scale = const float scale =
std::min(static_cast<float>(output_width_) / input_width, std::min(static_cast<float>(output_width_) / input_width,
static_cast<float>(output_height_) / input_height); static_cast<float>(output_height_) / input_height);
const int target_width = std::round(input_width * scale); const int target_width = std::round(input_width * scale);
const int target_height = std::round(input_height * scale); const int target_height = std::round(input_height * scale);
int scale_flag = scale < 1.0f ? cv::INTER_AREA : cv::INTER_LINEAR;
if (scale_mode_ == mediapipe::ScaleMode_Mode_FIT) { if (interpolation_mode_ == ImageTransformationCalculatorOptions::LINEAR) {
// Use INTER_AREA for downscaling if interpolation mode is set to
// LINEAR.
if (scale < 1.0f) {
opencv_interpolation_mode = cv::INTER_AREA;
} else {
opencv_interpolation_mode = cv::INTER_LINEAR;
}
} else {
opencv_interpolation_mode = cv::INTER_NEAREST;
}
if (scale_mode_ == mediapipe::ScaleMode::FIT) {
cv::Mat intermediate_mat; cv::Mat intermediate_mat;
cv::resize(input_mat, intermediate_mat, cv::resize(input_mat, intermediate_mat,
cv::Size(target_width, target_height), 0, 0, scale_flag); cv::Size(target_width, target_height), 0, 0,
opencv_interpolation_mode);
const int top = (output_height_ - target_height) / 2; const int top = (output_height_ - target_height) / 2;
const int bottom = output_height_ - target_height - top; const int bottom = output_height_ - target_height - top;
const int left = (output_width_ - target_width) / 2; const int left = (output_width_ - target_width) / 2;
@ -488,7 +519,7 @@ absl::Status ImageTransformationCalculator::RenderCpu(CalculatorContext* cc) {
padding_color_); padding_color_);
} else { } else {
cv::resize(input_mat, scaled_mat, cv::Size(target_width, target_height), cv::resize(input_mat, scaled_mat, cv::Size(target_width, target_height),
0, 0, scale_flag); 0, 0, opencv_interpolation_mode);
output_width = target_width; output_width = target_width;
output_height = target_height; output_height = target_height;
} }
@ -514,17 +545,17 @@ absl::Status ImageTransformationCalculator::RenderCpu(CalculatorContext* cc) {
cv::warpAffine(input_mat, rotated_mat, rotation_mat, rotated_size); cv::warpAffine(input_mat, rotated_mat, rotation_mat, rotated_size);
} else { } else {
switch (rotation_) { switch (rotation_) {
case mediapipe::RotationMode_Mode_UNKNOWN: case mediapipe::RotationMode::UNKNOWN:
case mediapipe::RotationMode_Mode_ROTATION_0: case mediapipe::RotationMode::ROTATION_0:
rotated_mat = input_mat; rotated_mat = input_mat;
break; break;
case mediapipe::RotationMode_Mode_ROTATION_90: case mediapipe::RotationMode::ROTATION_90:
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_COUNTERCLOCKWISE); cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_COUNTERCLOCKWISE);
break; break;
case mediapipe::RotationMode_Mode_ROTATION_180: case mediapipe::RotationMode::ROTATION_180:
cv::rotate(input_mat, rotated_mat, cv::ROTATE_180); cv::rotate(input_mat, rotated_mat, cv::ROTATE_180);
break; break;
case mediapipe::RotationMode_Mode_ROTATION_270: case mediapipe::RotationMode::ROTATION_270:
cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_CLOCKWISE); cv::rotate(input_mat, rotated_mat, cv::ROTATE_90_CLOCKWISE);
break; break;
} }
@ -561,7 +592,7 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) {
ComputeOutputDimensions(input_width, input_height, &output_width, ComputeOutputDimensions(input_width, input_height, &output_width,
&output_height); &output_height);
if (scale_mode_ == mediapipe::ScaleMode_Mode_FILL_AND_CROP) { if (scale_mode_ == mediapipe::ScaleMode::FILL_AND_CROP) {
const float scale = const float scale =
std::min(static_cast<float>(output_width_) / input_width, std::min(static_cast<float>(output_width_) / input_width,
static_cast<float>(output_height_) / input_height); static_cast<float>(output_height_) / input_height);
@ -628,11 +659,20 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) {
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
glBindTexture(src1.target(), src1.name()); glBindTexture(src1.target(), src1.name());
if (interpolation_mode_ == ImageTransformationCalculatorOptions::NEAREST) {
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
}
MP_RETURN_IF_ERROR(renderer->GlRender( MP_RETURN_IF_ERROR(renderer->GlRender(
src1.width(), src1.height(), dst.width(), dst.height(), scale_mode, src1.width(), src1.height(), dst.width(), dst.height(), scale_mode,
rotation, flip_horizontally_, flip_vertically_, rotation, flip_horizontally_, flip_vertically_,
/*flip_texture=*/false)); /*flip_texture=*/false));
// Reset interpolation modes to MediaPipe defaults.
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR);
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
glBindTexture(src1.target(), 0); glBindTexture(src1.target(), 0);
@ -652,8 +692,8 @@ void ImageTransformationCalculator::ComputeOutputDimensions(
if (output_width_ > 0 && output_height_ > 0) { if (output_width_ > 0 && output_height_ > 0) {
*output_width = output_width_; *output_width = output_width_;
*output_height = output_height_; *output_height = output_height_;
} else if (rotation_ == mediapipe::RotationMode_Mode_ROTATION_90 || } else if (rotation_ == mediapipe::RotationMode::ROTATION_90 ||
rotation_ == mediapipe::RotationMode_Mode_ROTATION_270) { rotation_ == mediapipe::RotationMode::ROTATION_270) {
*output_width = input_height; *output_width = input_height;
*output_height = input_width; *output_height = input_width;
} else { } else {
@ -666,9 +706,9 @@ void ImageTransformationCalculator::ComputeOutputLetterboxPadding(
int input_width, int input_height, int output_width, int output_height, int input_width, int input_height, int output_width, int output_height,
std::array<float, 4>* padding) { std::array<float, 4>* padding) {
padding->fill(0.f); padding->fill(0.f);
if (scale_mode_ == mediapipe::ScaleMode_Mode_FIT) { if (scale_mode_ == mediapipe::ScaleMode::FIT) {
if (rotation_ == mediapipe::RotationMode_Mode_ROTATION_90 || if (rotation_ == mediapipe::RotationMode::ROTATION_90 ||
rotation_ == mediapipe::RotationMode_Mode_ROTATION_270) { rotation_ == mediapipe::RotationMode::ROTATION_270) {
std::swap(input_width, input_height); std::swap(input_width, input_height);
} }
const float input_aspect_ratio = const float input_aspect_ratio =

View File

@ -54,4 +54,15 @@ message ImageTransformationCalculatorOptions {
// The color for the padding. This option is only used when the scale mode is // The color for the padding. This option is only used when the scale mode is
// FIT. Default is black. This is for CPU only. // FIT. Default is black. This is for CPU only.
optional Color padding_color = 8; optional Color padding_color = 8;
// Interpolation method to use. Note that on CPU when LINEAR is specified,
// INTER_LINEAR is used for upscaling and INTER_AREA is used for downscaling.
enum InterpolationMode {
DEFAULT = 0;
LINEAR = 1;
NEAREST = 2;
}
// Mode DEFAULT will use LINEAR interpolation.
optional InterpolationMode interpolation_mode = 9;
} }

View File

@ -0,0 +1,174 @@
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/flags/flag.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/googletest.h"
#include "third_party/OpenCV/core/mat.hpp"
namespace mediapipe {
namespace {
absl::flat_hash_set<int> computeUniqueValues(const cv::Mat& mat) {
// Compute the unique values in cv::Mat
absl::flat_hash_set<int> unique_values;
for (int i = 0; i < mat.rows; i++) {
for (int j = 0; j < mat.cols; j++) {
unique_values.insert(mat.at<unsigned char>(i, j));
}
}
return unique_values;
}
TEST(ImageTransformationCalculatorTest, NearestNeighborResizing) {
cv::Mat input_mat;
cv::cvtColor(cv::imread(file::JoinPath("./",
"/mediapipe/calculators/"
"image/testdata/binary_mask.png")),
input_mat, cv::COLOR_BGR2GRAY);
Packet input_image_packet = MakePacket<ImageFrame>(
ImageFormat::GRAY8, input_mat.size().width, input_mat.size().height);
input_mat.copyTo(formats::MatView(&(input_image_packet.Get<ImageFrame>())));
std::vector<std::pair<int, int>> output_dims{
{256, 333}, {512, 512}, {1024, 1024}};
for (auto& output_dim : output_dims) {
Packet input_output_dim_packet =
MakePacket<std::pair<int, int>>(output_dim);
std::vector<std::string> scale_modes{"FIT", "STRETCH"};
for (const auto& scale_mode : scale_modes) {
CalculatorGraphConfig::Node node_config =
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
absl::Substitute(R"(
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE:input_image"
input_stream: "OUTPUT_DIMENSIONS:image_size"
output_stream: "IMAGE:output_image"
options: {
[mediapipe.ImageTransformationCalculatorOptions.ext]: {
scale_mode: $0
interpolation_mode: NEAREST
}
})",
scale_mode));
CalculatorRunner runner(node_config);
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_image_packet.At(Timestamp(0)));
runner.MutableInputs()
->Tag("OUTPUT_DIMENSIONS")
.packets.push_back(input_output_dim_packet.At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
ASSERT_EQ(outputs.NumEntries(), 1);
const std::vector<Packet>& packets = outputs.Tag("IMAGE").packets;
ASSERT_EQ(packets.size(), 1);
const auto& result = packets[0].Get<ImageFrame>();
ASSERT_EQ(output_dim.first, result.Width());
ASSERT_EQ(output_dim.second, result.Height());
auto unique_input_values = computeUniqueValues(input_mat);
auto unique_output_values =
computeUniqueValues(formats::MatView(&result));
EXPECT_THAT(unique_input_values,
::testing::ContainerEq(unique_output_values));
}
}
}
TEST(ImageTransformationCalculatorTest, NearestNeighborResizingGpu) {
cv::Mat input_mat;
cv::cvtColor(cv::imread(file::JoinPath("./",
"/mediapipe/calculators/"
"image/testdata/binary_mask.png")),
input_mat, cv::COLOR_BGR2RGBA);
std::vector<std::pair<int, int>> output_dims{
{256, 333}, {512, 512}, {1024, 1024}};
for (auto& output_dim : output_dims) {
std::vector<std::string> scale_modes{"FIT"}; //, "STRETCH"};
for (const auto& scale_mode : scale_modes) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "input_image"
input_stream: "image_size"
output_stream: "output_image"
node {
calculator: "ImageFrameToGpuBufferCalculator"
input_stream: "input_image"
output_stream: "input_image_gpu"
}
node {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE_GPU:input_image_gpu"
input_stream: "OUTPUT_DIMENSIONS:image_size"
output_stream: "IMAGE_GPU:output_image_gpu"
options: {
[mediapipe.ImageTransformationCalculatorOptions.ext]: {
scale_mode: $0
interpolation_mode: NEAREST
}
}
}
node {
calculator: "GpuBufferToImageFrameCalculator"
input_stream: "output_image_gpu"
output_stream: "output_image"
})",
scale_mode));
ImageFrame input_image(ImageFormat::SRGBA, input_mat.size().width,
input_mat.size().height);
input_mat.copyTo(formats::MatView(&input_image));
std::vector<Packet> output_image_packets;
tool::AddVectorSink("output_image", &graph_config, &output_image_packets);
CalculatorGraph graph(graph_config);
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_image",
MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"image_size",
MakePacket<std::pair<int, int>>(output_dim).At(Timestamp(0))));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_THAT(output_image_packets, testing::SizeIs(1));
const auto& output_image = output_image_packets[0].Get<ImageFrame>();
ASSERT_EQ(output_dim.first, output_image.Width());
ASSERT_EQ(output_dim.second, output_image.Height());
auto unique_input_values = computeUniqueValues(input_mat);
auto unique_output_values =
computeUniqueValues(formats::MatView(&output_image));
EXPECT_THAT(unique_input_values,
::testing::ContainerEq(unique_output_values));
}
}
}
} // namespace
} // namespace mediapipe

View File

@ -18,6 +18,7 @@ licenses(["notice"])
filegroup( filegroup(
name = "test_images", name = "test_images",
srcs = [ srcs = [
"binary_mask.png",
"dino.jpg", "dino.jpg",
"dino_quality_50.jpg", "dino_quality_50.jpg",
"dino_quality_80.jpg", "dino_quality_80.jpg",

Binary file not shown.

After

Width:  |  Height:  |  Size: 771 B