Added the gray scale image support for the ImageToTensorCalculator on CPU.
PiperOrigin-RevId: 489593917
This commit is contained in:
parent
524ac3ca61
commit
bbd5da7971
|
@ -54,6 +54,13 @@ cv::Mat GetRgba(absl::string_view path) {
|
|||
return rgb;
|
||||
}
|
||||
|
||||
cv::Mat GetGray(absl::string_view path) {
|
||||
cv::Mat bgr = cv::imread(file::JoinPath("./", path));
|
||||
cv::Mat gray;
|
||||
cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY);
|
||||
return gray;
|
||||
}
|
||||
|
||||
// Image to tensor test template.
|
||||
// No processing/assertions should be done after the function is invoked.
|
||||
void RunTestWithInputImagePacket(const Packet& input_image_packet,
|
||||
|
@ -147,29 +154,34 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
|
|||
ASSERT_THAT(tensor_vec, testing::SizeIs(1));
|
||||
|
||||
const Tensor& tensor = tensor_vec[0];
|
||||
const int channels = tensor.shape().dims[3];
|
||||
ASSERT_TRUE(channels == 1 || channels == 3);
|
||||
auto view = tensor.GetCpuReadView();
|
||||
cv::Mat tensor_mat;
|
||||
if (output_int_tensor) {
|
||||
if (range_min < 0) {
|
||||
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kInt8);
|
||||
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8SC3,
|
||||
tensor_mat = cv::Mat(tensor_height, tensor_width,
|
||||
channels == 1 ? CV_8SC1 : CV_8SC3,
|
||||
const_cast<int8*>(view.buffer<int8>()));
|
||||
} else {
|
||||
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8);
|
||||
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3,
|
||||
tensor_mat = cv::Mat(tensor_height, tensor_width,
|
||||
channels == 1 ? CV_8UC1 : CV_8UC3,
|
||||
const_cast<uint8*>(view.buffer<uint8>()));
|
||||
}
|
||||
} else {
|
||||
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32);
|
||||
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3,
|
||||
tensor_mat = cv::Mat(tensor_height, tensor_width,
|
||||
channels == 1 ? CV_32FC1 : CV_32FC3,
|
||||
const_cast<float*>(view.buffer<float>()));
|
||||
}
|
||||
|
||||
cv::Mat result_rgb;
|
||||
auto transformation =
|
||||
GetValueRangeTransformation(range_min, range_max, 0.0f, 255.0f).value();
|
||||
tensor_mat.convertTo(result_rgb, CV_8UC3, transformation.scale,
|
||||
transformation.offset);
|
||||
tensor_mat.convertTo(result_rgb, channels == 1 ? CV_8UC1 : CV_8UC3,
|
||||
transformation.scale, transformation.offset);
|
||||
|
||||
cv::Mat diff;
|
||||
cv::absdiff(result_rgb, expected_result, diff);
|
||||
|
@ -185,17 +197,27 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
|
|||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
|
||||
if (image_channels == 4) {
|
||||
return ImageFormat::SRGBA;
|
||||
} else if (image_channels == 3) {
|
||||
return ImageFormat::SRGB;
|
||||
} else if (image_channels == 1) {
|
||||
return ImageFormat::GRAY8;
|
||||
}
|
||||
CHECK(false) << "Unsupported input image channles: " << image_channels;
|
||||
}
|
||||
|
||||
Packet MakeImageFramePacket(cv::Mat input) {
|
||||
ImageFrame input_image(
|
||||
input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB,
|
||||
input.cols, input.rows, input.step, input.data, [](uint8*) {});
|
||||
ImageFrame input_image(GetImageFormat(input.channels()), input.cols,
|
||||
input.rows, input.step, input.data, [](uint8*) {});
|
||||
return MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0));
|
||||
}
|
||||
|
||||
Packet MakeImagePacket(cv::Mat input) {
|
||||
mediapipe::Image input_image(std::make_shared<mediapipe::ImageFrame>(
|
||||
input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB,
|
||||
input.cols, input.rows, input.step, input.data, [](uint8*) {}));
|
||||
GetImageFormat(input.channels()), input.cols, input.rows, input.step,
|
||||
input.data, [](uint8*) {}));
|
||||
return MakePacket<mediapipe::Image>(std::move(input_image)).At(Timestamp(0));
|
||||
}
|
||||
|
||||
|
@ -429,6 +451,24 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
|||
/*border_mode=*/{}, roi);
|
||||
}
|
||||
|
||||
TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationGray) {
|
||||
mediapipe::NormalizedRect roi;
|
||||
roi.set_x_center(0.5f);
|
||||
roi.set_y_center(0.5f);
|
||||
roi.set_width(1.5f);
|
||||
roi.set_height(1.1f);
|
||||
roi.set_rotation(M_PI * -15.0f / 180.0f);
|
||||
RunTest(GetGray("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/input.jpg"),
|
||||
GetGray("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/"
|
||||
"large_sub_rect_keep_aspect_with_rotation.png"),
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
|
||||
/*border_mode=*/{}, roi);
|
||||
}
|
||||
|
||||
TEST(ImageToTensorCalculatorTest,
|
||||
LargeSubRectKeepAspectWithRotationBorderZero) {
|
||||
mediapipe::NormalizedRect roi;
|
||||
|
@ -448,6 +488,25 @@ TEST(ImageToTensorCalculatorTest,
|
|||
/*border_mode=*/BorderMode::kZero, roi);
|
||||
}
|
||||
|
||||
TEST(ImageToTensorCalculatorTest,
|
||||
LargeSubRectKeepAspectWithRotationBorderZeroGray) {
|
||||
mediapipe::NormalizedRect roi;
|
||||
roi.set_x_center(0.5f);
|
||||
roi.set_y_center(0.5f);
|
||||
roi.set_width(1.5f);
|
||||
roi.set_height(1.1f);
|
||||
roi.set_rotation(M_PI * -15.0f / 180.0f);
|
||||
RunTest(GetGray("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/input.jpg"),
|
||||
GetGray("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/"
|
||||
"large_sub_rect_keep_aspect_with_rotation_border_zero.png"),
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}},
|
||||
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
|
||||
/*border_mode=*/BorderMode::kZero, roi);
|
||||
}
|
||||
|
||||
TEST(ImageToTensorCalculatorTest, NoOpExceptRange) {
|
||||
mediapipe::NormalizedRect roi;
|
||||
roi.set_x_center(0.5f);
|
||||
|
|
|
@ -48,15 +48,19 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
switch (tensor_type_) {
|
||||
case Tensor::ElementType::kInt8:
|
||||
mat_type_ = CV_8SC3;
|
||||
mat_gray_type_ = CV_8SC1;
|
||||
break;
|
||||
case Tensor::ElementType::kFloat32:
|
||||
mat_type_ = CV_32FC3;
|
||||
mat_gray_type_ = CV_32FC1;
|
||||
break;
|
||||
case Tensor::ElementType::kUInt8:
|
||||
mat_type_ = CV_8UC3;
|
||||
mat_gray_type_ = CV_8UC1;
|
||||
break;
|
||||
default:
|
||||
mat_type_ = -1;
|
||||
mat_gray_type_ = -1;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,11 +68,13 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
float range_min, float range_max,
|
||||
int tensor_buffer_offset,
|
||||
Tensor& output_tensor) override {
|
||||
if (input.image_format() != mediapipe::ImageFormat::SRGB &&
|
||||
input.image_format() != mediapipe::ImageFormat::SRGBA) {
|
||||
return InvalidArgumentError(
|
||||
absl::StrCat("Only RGBA/RGB formats are supported, passed format: ",
|
||||
static_cast<uint32_t>(input.image_format())));
|
||||
const bool is_supported_format =
|
||||
input.image_format() == mediapipe::ImageFormat::SRGB ||
|
||||
input.image_format() == mediapipe::ImageFormat::SRGBA ||
|
||||
input.image_format() == mediapipe::ImageFormat::GRAY8;
|
||||
if (!is_supported_format) {
|
||||
return InvalidArgumentError(absl::StrCat(
|
||||
"Unsupported format: ", static_cast<uint32_t>(input.image_format())));
|
||||
}
|
||||
// TODO: Remove the check once tensor_buffer_offset > 0 is
|
||||
// supported.
|
||||
|
@ -82,17 +88,18 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
const int output_channels = output_shape.dims[3];
|
||||
auto buffer_view = output_tensor.GetCpuWriteView();
|
||||
cv::Mat dst;
|
||||
const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_;
|
||||
switch (tensor_type_) {
|
||||
case Tensor::ElementType::kInt8:
|
||||
dst = cv::Mat(output_height, output_width, mat_type_,
|
||||
dst = cv::Mat(output_height, output_width, dst_data_type,
|
||||
buffer_view.buffer<int8>());
|
||||
break;
|
||||
case Tensor::ElementType::kFloat32:
|
||||
dst = cv::Mat(output_height, output_width, mat_type_,
|
||||
dst = cv::Mat(output_height, output_width, dst_data_type,
|
||||
buffer_view.buffer<float>());
|
||||
break;
|
||||
case Tensor::ElementType::kUInt8:
|
||||
dst = cv::Mat(output_height, output_width, mat_type_,
|
||||
dst = cv::Mat(output_height, output_width, dst_data_type,
|
||||
buffer_view.buffer<uint8>());
|
||||
break;
|
||||
default:
|
||||
|
@ -137,7 +144,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
auto transform,
|
||||
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
|
||||
range_min, range_max));
|
||||
transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
|
||||
transformed.convertTo(dst, dst_data_type, transform.scale,
|
||||
transform.offset);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -148,7 +156,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||
"converter.";
|
||||
RET_CHECK_EQ(output_shape.dims[3], 3)
|
||||
RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1)
|
||||
<< "Wrong output channel: " << output_shape.dims[3];
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -156,6 +164,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
enum cv::BorderTypes border_mode_;
|
||||
Tensor::ElementType tensor_type_;
|
||||
int mat_type_;
|
||||
int mat_gray_type_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
|
|
@ -253,8 +253,11 @@ int GetNumOutputChannels(const mediapipe::Image& image) {
|
|||
}
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
// All of the processors except for Metal expect 3 channels.
|
||||
return 3;
|
||||
// The output tensor channel is 1 for the input image with 1 channel; And the
|
||||
// output tensor channels is 3 for the input image with 3 or 4 channels.
|
||||
// TODO: Add a unittest here to test the behavior on GPU, i.e.
|
||||
// failure.
|
||||
return image.channels() == 1 ? 1 : 3;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
|
|
Loading…
Reference in New Issue
Block a user