Internal change
PiperOrigin-RevId: 516349788
This commit is contained in:
parent
cb4b0ea93d
commit
1b4a835be0
|
@ -748,6 +748,7 @@ cc_test(
|
|||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
|
||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
|
||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
|
||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png",
|
||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
|
||||
],
|
||||
tags = ["desktop_only_test"],
|
||||
|
|
|
@ -29,6 +29,9 @@ class AffineTransformation {
|
|||
// pixels will be calculated.
|
||||
enum class BorderMode { kZero, kReplicate };
|
||||
|
||||
// Pixel sampling interpolation method.
|
||||
enum class Interpolation { kLinear, kCubic };
|
||||
|
||||
struct Size {
|
||||
int width;
|
||||
int height;
|
||||
|
|
|
@ -77,8 +77,11 @@ class GlTextureWarpAffineRunner
|
|||
std::unique_ptr<GpuBuffer>> {
|
||||
public:
|
||||
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
|
||||
GpuOrigin::Mode gpu_origin)
|
||||
: gl_helper_(gl_helper), gpu_origin_(gpu_origin) {}
|
||||
GpuOrigin::Mode gpu_origin,
|
||||
AffineTransformation::Interpolation interpolation)
|
||||
: gl_helper_(gl_helper),
|
||||
gpu_origin_(gpu_origin),
|
||||
interpolation_(interpolation) {}
|
||||
absl::Status Init() {
|
||||
return gl_helper_->RunInGlContext([this]() -> absl::Status {
|
||||
const GLint attr_location[kNumAttributes] = {
|
||||
|
@ -103,28 +106,83 @@ class GlTextureWarpAffineRunner
|
|||
}
|
||||
)";
|
||||
|
||||
// TODO Move bicubic code to common shared place.
|
||||
constexpr GLchar kFragShader[] = R"(
|
||||
DEFAULT_PRECISION(highp, float)
|
||||
in vec2 sample_coordinate;
|
||||
uniform sampler2D input_texture;
|
||||
DEFAULT_PRECISION(highp, float)
|
||||
|
||||
#ifdef GL_ES
|
||||
#define fragColor gl_FragColor
|
||||
#else
|
||||
out vec4 fragColor;
|
||||
#endif // defined(GL_ES);
|
||||
in vec2 sample_coordinate;
|
||||
uniform sampler2D input_texture;
|
||||
uniform vec2 input_size;
|
||||
|
||||
void main() {
|
||||
vec4 color = texture2D(input_texture, sample_coordinate);
|
||||
#ifdef CUSTOM_ZERO_BORDER_MODE
|
||||
float out_of_bounds =
|
||||
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
|
||||
sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0);
|
||||
color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds);
|
||||
#endif // defined(CUSTOM_ZERO_BORDER_MODE)
|
||||
fragColor = color;
|
||||
}
|
||||
)";
|
||||
#ifdef GL_ES
|
||||
#define fragColor gl_FragColor
|
||||
#else
|
||||
out vec4 fragColor;
|
||||
#endif // defined(GL_ES);
|
||||
|
||||
#ifdef CUBIC_INTERPOLATION
|
||||
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
|
||||
const vec2 halve = vec2(0.5,0.5);
|
||||
const vec2 one = vec2(1.0,1.0);
|
||||
const vec2 two = vec2(2.0,2.0);
|
||||
const vec2 three = vec2(3.0,3.0);
|
||||
const vec2 six = vec2(6.0,6.0);
|
||||
|
||||
// Calculate the fraction and integer.
|
||||
tex_coord = tex_coord * tex_size - halve;
|
||||
vec2 frac = fract(tex_coord);
|
||||
vec2 index = tex_coord - frac + halve;
|
||||
|
||||
// Calculate weights for Catmull-Rom filter.
|
||||
vec2 w0 = frac * (-halve + frac * (one - halve * frac));
|
||||
vec2 w1 = one + frac * frac * (-(two+halve) + three/two * frac);
|
||||
vec2 w2 = frac * (halve + frac * (two - three/two * frac));
|
||||
vec2 w3 = frac * frac * (-halve + halve * frac);
|
||||
|
||||
// Calculate weights to take advantage of bilinear texture lookup.
|
||||
vec2 w12 = w1 + w2;
|
||||
vec2 offset12 = w2 / (w1 + w2);
|
||||
|
||||
vec2 index_tl = index - one;
|
||||
vec2 index_br = index + two;
|
||||
vec2 index_eq = index + offset12;
|
||||
|
||||
index_tl /= tex_size;
|
||||
index_br /= tex_size;
|
||||
index_eq /= tex_size;
|
||||
|
||||
// 9 texture lookup and linear blending.
|
||||
vec4 color = vec4(0.0);
|
||||
color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y;
|
||||
color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x *w0.y;
|
||||
color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y;
|
||||
|
||||
color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y;
|
||||
color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x *w12.y;
|
||||
color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y;
|
||||
|
||||
color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y;
|
||||
color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x *w3.y;
|
||||
color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y;
|
||||
return color;
|
||||
}
|
||||
#else
|
||||
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
|
||||
return texture2D(tex, tex_coord);
|
||||
}
|
||||
#endif // defined(CUBIC_INTERPOLATION)
|
||||
|
||||
void main() {
|
||||
vec4 color = sample(input_texture, sample_coordinate, input_size);
|
||||
#ifdef CUSTOM_ZERO_BORDER_MODE
|
||||
float out_of_bounds =
|
||||
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
|
||||
sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0);
|
||||
color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds);
|
||||
#endif // defined(CUSTOM_ZERO_BORDER_MODE)
|
||||
fragColor = color;
|
||||
}
|
||||
)";
|
||||
|
||||
// Create program and set parameters.
|
||||
auto create_fn = [&](const std::string& vs,
|
||||
|
@ -137,14 +195,28 @@ class GlTextureWarpAffineRunner
|
|||
glUseProgram(program);
|
||||
glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
|
||||
GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
|
||||
return Program{.id = program, .matrix_id = matrix_id};
|
||||
GLint size_id = glGetUniformLocation(program, "input_size");
|
||||
return Program{
|
||||
.id = program, .matrix_id = matrix_id, .size_id = size_id};
|
||||
};
|
||||
|
||||
const std::string vert_src =
|
||||
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
|
||||
|
||||
const std::string frag_src = absl::StrCat(
|
||||
mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader);
|
||||
std::string interpolation_def;
|
||||
switch (interpolation_) {
|
||||
case AffineTransformation::Interpolation::kCubic:
|
||||
interpolation_def = R"(
|
||||
#define CUBIC_INTERPOLATION
|
||||
)";
|
||||
break;
|
||||
case AffineTransformation::Interpolation::kLinear:
|
||||
break;
|
||||
}
|
||||
|
||||
const std::string frag_src =
|
||||
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
|
||||
interpolation_def, kFragShader);
|
||||
|
||||
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
|
||||
|
||||
|
@ -152,9 +224,9 @@ class GlTextureWarpAffineRunner
|
|||
std::string custom_zero_border_mode_def = R"(
|
||||
#define CUSTOM_ZERO_BORDER_MODE
|
||||
)";
|
||||
const std::string frag_custom_zero_src =
|
||||
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
|
||||
custom_zero_border_mode_def, kFragShader);
|
||||
const std::string frag_custom_zero_src = absl::StrCat(
|
||||
mediapipe::kMediaPipeFragmentShaderPreamble,
|
||||
custom_zero_border_mode_def, interpolation_def, kFragShader);
|
||||
return create_fn(vert_src, frag_custom_zero_src);
|
||||
};
|
||||
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
||||
|
@ -256,6 +328,7 @@ class GlTextureWarpAffineRunner
|
|||
}
|
||||
glUseProgram(program->id);
|
||||
|
||||
// uniforms
|
||||
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
|
||||
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
|
||||
// @matrix describes affine transformation in terms of TOP LEFT origin, so
|
||||
|
@ -275,6 +348,10 @@ class GlTextureWarpAffineRunner
|
|||
eigen_mat.transposeInPlace();
|
||||
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
|
||||
|
||||
if (interpolation_ == AffineTransformation::Interpolation::kCubic) {
|
||||
glUniform2f(program->size_id, texture.width(), texture.height());
|
||||
}
|
||||
|
||||
// vao
|
||||
glBindVertexArray(vao_);
|
||||
|
||||
|
@ -327,6 +404,7 @@ class GlTextureWarpAffineRunner
|
|||
struct Program {
|
||||
GLuint id;
|
||||
GLint matrix_id;
|
||||
GLint size_id;
|
||||
};
|
||||
std::shared_ptr<GlCalculatorHelper> gl_helper_;
|
||||
GpuOrigin::Mode gpu_origin_;
|
||||
|
@ -335,6 +413,8 @@ class GlTextureWarpAffineRunner
|
|||
Program program_;
|
||||
std::optional<Program> program_custom_zero_;
|
||||
GLuint framebuffer_ = 0;
|
||||
AffineTransformation::Interpolation interpolation_ =
|
||||
AffineTransformation::Interpolation::kLinear;
|
||||
};
|
||||
|
||||
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
||||
|
@ -344,9 +424,10 @@ class GlTextureWarpAffineRunner
|
|||
absl::StatusOr<std::unique_ptr<
|
||||
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
|
||||
CreateAffineTransformationGlRunner(
|
||||
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin) {
|
||||
auto runner =
|
||||
absl::make_unique<GlTextureWarpAffineRunner>(gl_helper, gpu_origin);
|
||||
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin,
|
||||
AffineTransformation::Interpolation interpolation) {
|
||||
auto runner = absl::make_unique<GlTextureWarpAffineRunner>(
|
||||
gl_helper, gpu_origin, interpolation);
|
||||
MP_RETURN_IF_ERROR(runner->Init());
|
||||
return runner;
|
||||
}
|
||||
|
|
|
@ -29,7 +29,8 @@ absl::StatusOr<std::unique_ptr<AffineTransformation::Runner<
|
|||
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
|
||||
CreateAffineTransformationGlRunner(
|
||||
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
|
||||
mediapipe::GpuOrigin::Mode gpu_origin);
|
||||
mediapipe::GpuOrigin::Mode gpu_origin,
|
||||
AffineTransformation::Interpolation interpolation);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
|
|
|
@ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv(
|
|||
}
|
||||
}
|
||||
|
||||
// Maps the runner-level interpolation setting onto the matching OpenCV
// interpolation flag (the value that is later OR-ed into the cv::warpAffine
// `flags` argument).
//
// The switch intentionally has no `default:` label so the compiler can still
// warn when a new AffineTransformation::Interpolation enumerator is added
// without being handled here.
int GetInterpolationForOpenCv(
    AffineTransformation::Interpolation interpolation) {
  switch (interpolation) {
    case AffineTransformation::Interpolation::kLinear:
      return cv::INTER_LINEAR;
    case AffineTransformation::Interpolation::kCubic:
      return cv::INTER_CUBIC;
  }
  // Only reachable when `interpolation` holds a value outside the enumerators
  // above (e.g. produced by a cast). Flowing off the end of a non-void
  // function is undefined behavior, so fall back to linear sampling, which is
  // also the default value of OpenCvRunner::interpolation_.
  return cv::INTER_LINEAR;
}
|
||||
|
||||
class OpenCvRunner
|
||||
: public AffineTransformation::Runner<ImageFrame, ImageFrame> {
|
||||
public:
|
||||
OpenCvRunner(AffineTransformation::Interpolation interpolation)
|
||||
: interpolation_(GetInterpolationForOpenCv(interpolation)) {}
|
||||
|
||||
absl::StatusOr<ImageFrame> Run(
|
||||
const ImageFrame& input, const std::array<float, 16>& matrix,
|
||||
const AffineTransformation::Size& size,
|
||||
|
@ -142,19 +155,23 @@ class OpenCvRunner
|
|||
|
||||
cv::warpAffine(in_mat, out_mat, cv_affine_transform,
|
||||
cv::Size(out_mat.cols, out_mat.rows),
|
||||
/*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP,
|
||||
/*flags=*/interpolation_ | cv::WARP_INVERSE_MAP,
|
||||
GetBorderModeForOpenCv(border_mode));
|
||||
|
||||
return out_image;
|
||||
}
|
||||
|
||||
private:
|
||||
int interpolation_ = cv::INTER_LINEAR;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<
|
||||
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
||||
CreateAffineTransformationOpenCvRunner() {
|
||||
return absl::make_unique<OpenCvRunner>();
|
||||
CreateAffineTransformationOpenCvRunner(
|
||||
AffineTransformation::Interpolation interpolation) {
|
||||
return absl::make_unique<OpenCvRunner>(interpolation);
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -25,7 +25,8 @@ namespace mediapipe {
|
|||
|
||||
absl::StatusOr<
|
||||
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
||||
CreateAffineTransformationOpenCvRunner();
|
||||
CreateAffineTransformationOpenCvRunner(
|
||||
AffineTransformation::Interpolation interpolation);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
|
|
|
@ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode(
|
|||
}
|
||||
}
|
||||
|
||||
// Converts the calculator-options interpolation enum into the
// AffineTransformation interpolation enum used by the runners.
//
// INTER_UNSPECIFIED (including an unset option) is treated as INTER_LINEAR.
// The switch intentionally has no `default:` label so the compiler can still
// warn when a new options enumerator is added without being handled here.
AffineTransformation::Interpolation GetInterpolation(
    mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) {
  switch (interpolation) {
    case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED:
    case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR:
      return AffineTransformation::Interpolation::kLinear;
    case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC:
      return AffineTransformation::Interpolation::kCubic;
  }
  // Only reachable when `interpolation` holds a value outside the enumerators
  // above. Flowing off the end of a non-void function is undefined behavior,
  // so fall back to the same result as the unspecified case.
  return AffineTransformation::Interpolation::kLinear;
}
|
||||
|
||||
template <typename ImageT>
|
||||
class WarpAffineRunnerHolder {};
|
||||
|
||||
|
@ -61,16 +72,22 @@ template <>
|
|||
class WarpAffineRunnerHolder<ImageFrame> {
|
||||
public:
|
||||
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
|
||||
absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); }
|
||||
absl::Status Open(CalculatorContext* cc) {
|
||||
interpolation_ = GetInterpolation(
|
||||
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::StatusOr<RunnerType*> GetRunner() {
|
||||
if (!runner_) {
|
||||
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner());
|
||||
ASSIGN_OR_RETURN(runner_,
|
||||
CreateAffineTransformationOpenCvRunner(interpolation_));
|
||||
}
|
||||
return runner_.get();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<RunnerType> runner_;
|
||||
AffineTransformation::Interpolation interpolation_;
|
||||
};
|
||||
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
||||
|
||||
|
@ -85,12 +102,14 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
|
|||
gpu_origin_ =
|
||||
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
|
||||
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
|
||||
interpolation_ = GetInterpolation(
|
||||
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
|
||||
return gl_helper_->Open(cc);
|
||||
}
|
||||
absl::StatusOr<RunnerType*> GetRunner() {
|
||||
if (!runner_) {
|
||||
ASSIGN_OR_RETURN(
|
||||
runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_));
|
||||
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner(
|
||||
gl_helper_, gpu_origin_, interpolation_));
|
||||
}
|
||||
return runner_.get();
|
||||
}
|
||||
|
@ -99,6 +118,7 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
|
|||
mediapipe::GpuOrigin::Mode gpu_origin_;
|
||||
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
|
||||
std::unique_ptr<RunnerType> runner_;
|
||||
AffineTransformation::Interpolation interpolation_;
|
||||
};
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
|
|
|
@ -31,6 +31,13 @@ message WarpAffineCalculatorOptions {
|
|||
BORDER_REPLICATE = 2;
|
||||
}
|
||||
|
||||
// Pixel sampling interpolation methods. See @interpolation.
|
||||
enum Interpolation {
|
||||
INTER_UNSPECIFIED = 0;
|
||||
INTER_LINEAR = 1;
|
||||
INTER_CUBIC = 2;
|
||||
}
|
||||
|
||||
// Pixel extrapolation method.
|
||||
// When converting image to tensor it may happen that tensor needs to read
|
||||
// pixels outside image boundaries. Border mode helps to specify how such
|
||||
|
@ -43,4 +50,10 @@ message WarpAffineCalculatorOptions {
|
|||
// to be flipped vertically as tensors are expected to start at top.
|
||||
// (DEFAULT or unset interpreted as CONVENTIONAL.)
|
||||
optional GpuOrigin.Mode gpu_origin = 2;
|
||||
|
||||
// Sampling method for neighboring pixels.
|
||||
// INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors.
|
||||
// INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights.
|
||||
// INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR.
|
||||
optional Interpolation interpolation = 3;
|
||||
}
|
||||
|
|
|
@ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag,
|
|||
const cv::Mat& input, cv::Mat expected_result,
|
||||
float similarity_threshold, std::array<float, 16> matrix,
|
||||
int out_width, int out_height,
|
||||
absl::optional<AffineTransformation::BorderMode> border_mode) {
|
||||
std::optional<AffineTransformation::BorderMode> border_mode,
|
||||
std::optional<AffineTransformation::Interpolation> interpolation) {
|
||||
std::string border_mode_str;
|
||||
if (border_mode) {
|
||||
switch (*border_mode) {
|
||||
|
@ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag,
|
|||
break;
|
||||
}
|
||||
}
|
||||
std::string interpolation_str;
|
||||
if (interpolation) {
|
||||
switch (*interpolation) {
|
||||
case AffineTransformation::Interpolation::kLinear:
|
||||
interpolation_str = "interpolation: INTER_LINEAR";
|
||||
break;
|
||||
case AffineTransformation::Interpolation::kCubic:
|
||||
interpolation_str = "interpolation: INTER_CUBIC";
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
absl::Substitute(graph_text, /*$0=*/border_mode_str));
|
||||
absl::Substitute(graph_text, /*$0=*/border_mode_str,
|
||||
/*$1=*/interpolation_str));
|
||||
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("output_image", &graph_config, &output_packets);
|
||||
|
@ -132,7 +145,8 @@ struct SimilarityConfig {
|
|||
void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||
const SimilarityConfig& similarity, std::array<float, 16> matrix,
|
||||
int out_width, int out_height,
|
||||
absl::optional<AffineTransformation::BorderMode> border_mode) {
|
||||
std::optional<AffineTransformation::BorderMode> border_mode,
|
||||
std::optional<AffineTransformation::Interpolation> interpolation) {
|
||||
RunTest(R"(
|
||||
input_stream: "input_image"
|
||||
input_stream: "output_size"
|
||||
|
@ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
|||
options {
|
||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||
$0 # border mode
|
||||
$1 # interpolation
|
||||
}
|
||||
}
|
||||
}
|
||||
)",
|
||||
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
|
||||
RunTest(R"(
|
||||
input_stream: "input_image"
|
||||
|
@ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
|||
options {
|
||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||
$0 # border mode
|
||||
$1 # interpolation
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
|||
}
|
||||
)",
|
||||
"cpu_image", input, expected_result, similarity.threshold_on_cpu,
|
||||
matrix, out_width, out_height, border_mode);
|
||||
matrix, out_width, out_height, border_mode, interpolation);
|
||||
|
||||
RunTest(R"(
|
||||
input_stream: "input_image"
|
||||
|
@ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
|||
options {
|
||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||
$0 # border mode
|
||||
$1 # interpolation
|
||||
gpu_origin: TOP_LEFT
|
||||
}
|
||||
}
|
||||
|
@ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
|||
}
|
||||
)",
|
||||
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
|
||||
RunTest(R"(
|
||||
input_stream: "input_image"
|
||||
|
@ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
|||
options {
|
||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||
$0 # border mode
|
||||
$1 # interpolation
|
||||
gpu_origin: TOP_LEFT
|
||||
}
|
||||
}
|
||||
|
@ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
|||
}
|
||||
)",
|
||||
"gpu_image", input, expected_result, similarity.threshold_on_gpu,
|
||||
matrix, out_width, out_height, border_mode);
|
||||
matrix, out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
|
||||
|
@ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) {
|
|||
int out_height = 256;
|
||||
bool keep_aspect_ratio = true;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
||||
|
@ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
|||
bool keep_aspect_ratio = true;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kZero;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
||||
|
@ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
|||
bool keep_aspect_ratio = true;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kReplicate;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
||||
|
@ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
|||
bool keep_aspect_ratio = true;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kZero;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
||||
|
@ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
|||
bool keep_aspect_ratio = false;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kReplicate;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
||||
|
@ -411,10 +434,38 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
|||
bool keep_aspect_ratio = false;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kZero;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
// Checks a rotated sub-rectangle warp with zero border mode and bicubic
// interpolation against a stored reference image, using separate similarity
// thresholds for the CPU and GPU paths.
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) {
  // Output geometry requested from the calculator.
  constexpr int kOutWidth = 256;
  constexpr int kOutHeight = 256;
  constexpr bool kKeepAspectRatio = false;

  // Region of interest in normalized coordinates, rotated by -45 degrees
  // (expressed in radians).
  mediapipe::NormalizedRect roi;
  roi.set_rotation(M_PI * -45.0f / 180.0f);
  roi.set_width(0.5f);
  roi.set_height(0.5f);
  roi.set_x_center(0.65f);
  roi.set_y_center(0.4f);

  const auto input = GetRgb(
      "/mediapipe/calculators/tensor/testdata/"
      "image_to_tensor/input.jpg");
  const auto expected_output = GetRgb(
      "/mediapipe/calculators/tensor/testdata/image_to_tensor/"
      "medium_sub_rect_with_rotation_border_zero_interp_cubic.png");

  // Options under test: zero border combined with cubic interpolation.
  const std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kZero;
  const std::optional<AffineTransformation::Interpolation> interpolation =
      AffineTransformation::Interpolation::kCubic;

  const std::array<float, 16> matrix =
      GetMatrix(input, roi, kKeepAspectRatio, kOutWidth, kOutHeight);

  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78}, matrix,
          kOutWidth, kOutHeight, border_mode, interpolation);
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
||||
|
@ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
|||
bool keep_aspect_ratio = false;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kReplicate;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
||||
|
@ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
|||
bool keep_aspect_ratio = false;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kZero;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
||||
|
@ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
|||
bool keep_aspect_ratio = true;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kReplicate;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
||||
|
@ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
|||
bool keep_aspect_ratio = true;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kZero;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
||||
|
@ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
|||
int out_height = 128;
|
||||
bool keep_aspect_ratio = true;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
||||
|
@ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
|||
bool keep_aspect_ratio = true;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kZero;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, NoOp) {
|
||||
|
@ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) {
|
|||
bool keep_aspect_ratio = true;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kReplicate;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
||||
|
@ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
|||
bool keep_aspect_ratio = true;
|
||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||
AffineTransformation::BorderMode::kZero;
|
||||
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||
RunTest(input, expected_output,
|
||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||
out_width, out_height, border_mode);
|
||||
out_width, out_height, border_mode, interpolation);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 64 KiB |
Loading…
Reference in New Issue
Block a user