Internal change

PiperOrigin-RevId: 516349788
This commit is contained in:
Chris McClanahan 2023-03-13 16:07:36 -07:00 committed by Copybara-Service
parent cb4b0ea93d
commit 1b4a835be0
10 changed files with 256 additions and 60 deletions

View File

@ -748,6 +748,7 @@ cc_test(
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
],
tags = ["desktop_only_test"],

View File

@ -29,6 +29,9 @@ class AffineTransformation {
// pixels will be calculated.
enum class BorderMode { kZero, kReplicate };
// Pixel sampling interpolation method.
enum class Interpolation { kLinear, kCubic };
struct Size {
int width;
int height;

View File

@ -77,8 +77,11 @@ class GlTextureWarpAffineRunner
std::unique_ptr<GpuBuffer>> {
public:
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
GpuOrigin::Mode gpu_origin)
: gl_helper_(gl_helper), gpu_origin_(gpu_origin) {}
GpuOrigin::Mode gpu_origin,
AffineTransformation::Interpolation interpolation)
: gl_helper_(gl_helper),
gpu_origin_(gpu_origin),
interpolation_(interpolation) {}
absl::Status Init() {
return gl_helper_->RunInGlContext([this]() -> absl::Status {
const GLint attr_location[kNumAttributes] = {
@ -103,28 +106,83 @@ class GlTextureWarpAffineRunner
}
)";
// TODO Move bicubic code to common shared place.
constexpr GLchar kFragShader[] = R"(
DEFAULT_PRECISION(highp, float)
in vec2 sample_coordinate;
uniform sampler2D input_texture;
DEFAULT_PRECISION(highp, float)
#ifdef GL_ES
#define fragColor gl_FragColor
#else
out vec4 fragColor;
#endif // defined(GL_ES);
in vec2 sample_coordinate;
uniform sampler2D input_texture;
uniform vec2 input_size;
void main() {
vec4 color = texture2D(input_texture, sample_coordinate);
#ifdef CUSTOM_ZERO_BORDER_MODE
float out_of_bounds =
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0);
color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds);
#endif // defined(CUSTOM_ZERO_BORDER_MODE)
fragColor = color;
}
)";
#ifdef GL_ES
#define fragColor gl_FragColor
#else
out vec4 fragColor;
#endif // defined(GL_ES);
#ifdef CUBIC_INTERPOLATION
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
const vec2 halve = vec2(0.5,0.5);
const vec2 one = vec2(1.0,1.0);
const vec2 two = vec2(2.0,2.0);
const vec2 three = vec2(3.0,3.0);
const vec2 six = vec2(6.0,6.0);
// Calculate the fraction and integer.
tex_coord = tex_coord * tex_size - halve;
vec2 frac = fract(tex_coord);
vec2 index = tex_coord - frac + halve;
// Calculate weights for Catmull-Rom filter.
vec2 w0 = frac * (-halve + frac * (one - halve * frac));
vec2 w1 = one + frac * frac * (-(two+halve) + three/two * frac);
vec2 w2 = frac * (halve + frac * (two - three/two * frac));
vec2 w3 = frac * frac * (-halve + halve * frac);
// Calculate weights to take advantage of bilinear texture lookup.
vec2 w12 = w1 + w2;
vec2 offset12 = w2 / (w1 + w2);
vec2 index_tl = index - one;
vec2 index_br = index + two;
vec2 index_eq = index + offset12;
index_tl /= tex_size;
index_br /= tex_size;
index_eq /= tex_size;
// 9 texture lookup and linear blending.
vec4 color = vec4(0.0);
color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y;
color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x *w0.y;
color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y;
color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y;
color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x *w12.y;
color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y;
color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y;
color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x *w3.y;
color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y;
return color;
}
#else
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
return texture2D(tex, tex_coord);
}
#endif // defined(CUBIC_INTERPOLATION)
void main() {
vec4 color = sample(input_texture, sample_coordinate, input_size);
#ifdef CUSTOM_ZERO_BORDER_MODE
float out_of_bounds =
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0);
color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds);
#endif // defined(CUSTOM_ZERO_BORDER_MODE)
fragColor = color;
}
)";
// Create program and set parameters.
auto create_fn = [&](const std::string& vs,
@ -137,14 +195,28 @@ class GlTextureWarpAffineRunner
glUseProgram(program);
glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
return Program{.id = program, .matrix_id = matrix_id};
GLint size_id = glGetUniformLocation(program, "input_size");
return Program{
.id = program, .matrix_id = matrix_id, .size_id = size_id};
};
const std::string vert_src =
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
const std::string frag_src = absl::StrCat(
mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader);
std::string interpolation_def;
switch (interpolation_) {
case AffineTransformation::Interpolation::kCubic:
interpolation_def = R"(
#define CUBIC_INTERPOLATION
)";
break;
case AffineTransformation::Interpolation::kLinear:
break;
}
const std::string frag_src =
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
interpolation_def, kFragShader);
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
@ -152,9 +224,9 @@ class GlTextureWarpAffineRunner
std::string custom_zero_border_mode_def = R"(
#define CUSTOM_ZERO_BORDER_MODE
)";
const std::string frag_custom_zero_src =
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
custom_zero_border_mode_def, kFragShader);
const std::string frag_custom_zero_src = absl::StrCat(
mediapipe::kMediaPipeFragmentShaderPreamble,
custom_zero_border_mode_def, interpolation_def, kFragShader);
return create_fn(vert_src, frag_custom_zero_src);
};
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
@ -256,6 +328,7 @@ class GlTextureWarpAffineRunner
}
glUseProgram(program->id);
// uniforms
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
// @matrix describes affine transformation in terms of TOP LEFT origin, so
@ -275,6 +348,10 @@ class GlTextureWarpAffineRunner
eigen_mat.transposeInPlace();
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
if (interpolation_ == AffineTransformation::Interpolation::kCubic) {
glUniform2f(program->size_id, texture.width(), texture.height());
}
// vao
glBindVertexArray(vao_);
@ -327,6 +404,7 @@ class GlTextureWarpAffineRunner
struct Program {
GLuint id;
GLint matrix_id;
GLint size_id;
};
std::shared_ptr<GlCalculatorHelper> gl_helper_;
GpuOrigin::Mode gpu_origin_;
@ -335,6 +413,8 @@ class GlTextureWarpAffineRunner
Program program_;
std::optional<Program> program_custom_zero_;
GLuint framebuffer_ = 0;
AffineTransformation::Interpolation interpolation_ =
AffineTransformation::Interpolation::kLinear;
};
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
@ -344,9 +424,10 @@ class GlTextureWarpAffineRunner
absl::StatusOr<std::unique_ptr<
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
CreateAffineTransformationGlRunner(
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin) {
auto runner =
absl::make_unique<GlTextureWarpAffineRunner>(gl_helper, gpu_origin);
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin,
AffineTransformation::Interpolation interpolation) {
auto runner = absl::make_unique<GlTextureWarpAffineRunner>(
gl_helper, gpu_origin, interpolation);
MP_RETURN_IF_ERROR(runner->Init());
return runner;
}

View File

@ -29,7 +29,8 @@ absl::StatusOr<std::unique_ptr<AffineTransformation::Runner<
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
CreateAffineTransformationGlRunner(
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
mediapipe::GpuOrigin::Mode gpu_origin);
mediapipe::GpuOrigin::Mode gpu_origin,
AffineTransformation::Interpolation interpolation);
} // namespace mediapipe

View File

@ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv(
}
}
int GetInterpolationForOpenCv(
AffineTransformation::Interpolation interpolation) {
switch (interpolation) {
case AffineTransformation::Interpolation::kLinear:
return cv::INTER_LINEAR;
case AffineTransformation::Interpolation::kCubic:
return cv::INTER_CUBIC;
}
}
class OpenCvRunner
: public AffineTransformation::Runner<ImageFrame, ImageFrame> {
public:
OpenCvRunner(AffineTransformation::Interpolation interpolation)
: interpolation_(GetInterpolationForOpenCv(interpolation)) {}
absl::StatusOr<ImageFrame> Run(
const ImageFrame& input, const std::array<float, 16>& matrix,
const AffineTransformation::Size& size,
@ -142,19 +155,23 @@ class OpenCvRunner
cv::warpAffine(in_mat, out_mat, cv_affine_transform,
cv::Size(out_mat.cols, out_mat.rows),
/*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP,
/*flags=*/interpolation_ | cv::WARP_INVERSE_MAP,
GetBorderModeForOpenCv(border_mode));
return out_image;
}
private:
int interpolation_ = cv::INTER_LINEAR;
};
} // namespace
absl::StatusOr<
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
CreateAffineTransformationOpenCvRunner() {
return absl::make_unique<OpenCvRunner>();
CreateAffineTransformationOpenCvRunner(
AffineTransformation::Interpolation interpolation) {
return absl::make_unique<OpenCvRunner>(interpolation);
}
} // namespace mediapipe

View File

@ -25,7 +25,8 @@ namespace mediapipe {
absl::StatusOr<
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
CreateAffineTransformationOpenCvRunner();
CreateAffineTransformationOpenCvRunner(
AffineTransformation::Interpolation interpolation);
} // namespace mediapipe

View File

@ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode(
}
}
AffineTransformation::Interpolation GetInterpolation(
mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) {
switch (interpolation) {
case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED:
case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR:
return AffineTransformation::Interpolation::kLinear;
case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC:
return AffineTransformation::Interpolation::kCubic;
}
}
template <typename ImageT>
class WarpAffineRunnerHolder {};
@ -61,16 +72,22 @@ template <>
class WarpAffineRunnerHolder<ImageFrame> {
public:
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); }
absl::Status Open(CalculatorContext* cc) {
interpolation_ = GetInterpolation(
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
return absl::OkStatus();
}
absl::StatusOr<RunnerType*> GetRunner() {
if (!runner_) {
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner());
ASSIGN_OR_RETURN(runner_,
CreateAffineTransformationOpenCvRunner(interpolation_));
}
return runner_.get();
}
private:
std::unique_ptr<RunnerType> runner_;
AffineTransformation::Interpolation interpolation_;
};
#endif // !MEDIAPIPE_DISABLE_OPENCV
@ -85,12 +102,14 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
gpu_origin_ =
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
interpolation_ = GetInterpolation(
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
return gl_helper_->Open(cc);
}
absl::StatusOr<RunnerType*> GetRunner() {
if (!runner_) {
ASSIGN_OR_RETURN(
runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_));
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner(
gl_helper_, gpu_origin_, interpolation_));
}
return runner_.get();
}
@ -99,6 +118,7 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
mediapipe::GpuOrigin::Mode gpu_origin_;
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
std::unique_ptr<RunnerType> runner_;
AffineTransformation::Interpolation interpolation_;
};
#endif // !MEDIAPIPE_DISABLE_GPU

View File

@ -31,6 +31,13 @@ message WarpAffineCalculatorOptions {
BORDER_REPLICATE = 2;
}
// Pixel sampling interpolation methods. See @interpolation.
enum Interpolation {
INTER_UNSPECIFIED = 0;
INTER_LINEAR = 1;
INTER_CUBIC = 2;
}
// Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such
@ -43,4 +50,10 @@ message WarpAffineCalculatorOptions {
// to be flipped vertically as tensors are expected to start at top.
// (DEFAULT or unset interpreted as CONVENTIONAL.)
optional GpuOrigin.Mode gpu_origin = 2;
// Sampling method for neighboring pixels.
// INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors.
// INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights.
// INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR.
optional Interpolation interpolation = 3;
}

View File

@ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag,
const cv::Mat& input, cv::Mat expected_result,
float similarity_threshold, std::array<float, 16> matrix,
int out_width, int out_height,
absl::optional<AffineTransformation::BorderMode> border_mode) {
std::optional<AffineTransformation::BorderMode> border_mode,
std::optional<AffineTransformation::Interpolation> interpolation) {
std::string border_mode_str;
if (border_mode) {
switch (*border_mode) {
@ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag,
break;
}
}
std::string interpolation_str;
if (interpolation) {
switch (*interpolation) {
case AffineTransformation::Interpolation::kLinear:
interpolation_str = "interpolation: INTER_LINEAR";
break;
case AffineTransformation::Interpolation::kCubic:
interpolation_str = "interpolation: INTER_CUBIC";
break;
}
}
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(graph_text, /*$0=*/border_mode_str));
absl::Substitute(graph_text, /*$0=*/border_mode_str,
/*$1=*/interpolation_str));
std::vector<Packet> output_packets;
tool::AddVectorSink("output_image", &graph_config, &output_packets);
@ -132,7 +145,8 @@ struct SimilarityConfig {
void RunTest(cv::Mat input, cv::Mat expected_result,
const SimilarityConfig& similarity, std::array<float, 16> matrix,
int out_width, int out_height,
absl::optional<AffineTransformation::BorderMode> border_mode) {
std::optional<AffineTransformation::BorderMode> border_mode,
std::optional<AffineTransformation::Interpolation> interpolation) {
RunTest(R"(
input_stream: "input_image"
input_stream: "output_size"
@ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options {
[mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode
$1 # interpolation
}
}
}
)",
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
RunTest(R"(
input_stream: "input_image"
@ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options {
[mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode
$1 # interpolation
}
}
}
@ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
}
)",
"cpu_image", input, expected_result, similarity.threshold_on_cpu,
matrix, out_width, out_height, border_mode);
matrix, out_width, out_height, border_mode, interpolation);
RunTest(R"(
input_stream: "input_image"
@ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options {
[mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode
$1 # interpolation
gpu_origin: TOP_LEFT
}
}
@ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
}
)",
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
RunTest(R"(
input_stream: "input_image"
@ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options {
[mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode
$1 # interpolation
gpu_origin: TOP_LEFT
}
}
@ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
}
)",
"gpu_image", input, expected_result, similarity.threshold_on_gpu,
matrix, out_width, out_height, border_mode);
matrix, out_width, out_height, border_mode, interpolation);
}
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
@ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) {
int out_height = 256;
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = {};
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
@ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
@ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
@ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
@ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
@ -411,10 +434,38 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) {
mediapipe::NormalizedRect roi;
roi.set_x_center(0.65f);
roi.set_y_center(0.4f);
roi.set_width(0.5f);
roi.set_height(0.5f);
roi.set_rotation(M_PI * -45.0f / 180.0f);
auto input = GetRgb(
"/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/input.jpg");
auto expected_output = GetRgb(
"/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"medium_sub_rect_with_rotation_border_zero_interp_cubic.png");
int out_width = 256;
int out_height = 256;
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation =
AffineTransformation::Interpolation::kCubic;
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRect) {
@ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) {
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
@ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
@ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
@ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
@ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
int out_height = 128;
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = {};
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
@ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, NoOp) {
@ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
@ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
} // namespace

Binary file not shown.

After

Width:  |  Height:  |  Size: 64 KiB