Internal change
PiperOrigin-RevId: 490089940
This commit is contained in:
parent
adddf2c2ab
commit
d43d0ff615
|
@ -76,31 +76,49 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
return InvalidArgumentError(absl::StrCat(
|
||||
"Unsupported format: ", static_cast<uint32_t>(input.image_format())));
|
||||
}
|
||||
// TODO: Remove the check once tensor_buffer_offset > 0 is
|
||||
// supported.
|
||||
RET_CHECK_EQ(tensor_buffer_offset, 0)
|
||||
<< "The non-zero tensor_buffer_offset input is not supported yet.";
|
||||
|
||||
RET_CHECK_GE(tensor_buffer_offset, 0)
|
||||
<< "The input tensor_buffer_offset needs to be non-negative.";
|
||||
const auto& output_shape = output_tensor.shape();
|
||||
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||
|
||||
const int output_height = output_shape.dims[1];
|
||||
const int output_width = output_shape.dims[2];
|
||||
const int output_channels = output_shape.dims[3];
|
||||
const int num_elements_per_img =
|
||||
output_height * output_width * output_channels;
|
||||
auto buffer_view = output_tensor.GetCpuWriteView();
|
||||
cv::Mat dst;
|
||||
const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_;
|
||||
switch (tensor_type_) {
|
||||
case Tensor::ElementType::kInt8:
|
||||
dst = cv::Mat(output_height, output_width, dst_data_type,
|
||||
buffer_view.buffer<int8>());
|
||||
RET_CHECK_GE(output_shape.num_elements(),
|
||||
tensor_buffer_offset / sizeof(int8) + num_elements_per_img)
|
||||
<< "The buffer offset + the input image size is larger than the "
|
||||
"allocated tensor buffer.";
|
||||
dst = cv::Mat(
|
||||
output_height, output_width, dst_data_type,
|
||||
buffer_view.buffer<int8>() + tensor_buffer_offset / sizeof(int8));
|
||||
break;
|
||||
case Tensor::ElementType::kFloat32:
|
||||
dst = cv::Mat(output_height, output_width, dst_data_type,
|
||||
buffer_view.buffer<float>());
|
||||
RET_CHECK_GE(
|
||||
output_shape.num_elements(),
|
||||
tensor_buffer_offset / sizeof(float) + num_elements_per_img)
|
||||
<< "The buffer offset + the input image size is larger than the "
|
||||
"allocated tensor buffer.";
|
||||
dst = cv::Mat(
|
||||
output_height, output_width, dst_data_type,
|
||||
buffer_view.buffer<float>() + tensor_buffer_offset / sizeof(float));
|
||||
break;
|
||||
case Tensor::ElementType::kUInt8:
|
||||
dst = cv::Mat(output_height, output_width, dst_data_type,
|
||||
buffer_view.buffer<uint8>());
|
||||
RET_CHECK_GE(
|
||||
output_shape.num_elements(),
|
||||
tensor_buffer_offset / sizeof(uint8) + num_elements_per_img)
|
||||
<< "The buffer offset + the input image size is larger than the "
|
||||
"allocated tensor buffer.";
|
||||
dst = cv::Mat(
|
||||
output_height, output_width, dst_data_type,
|
||||
buffer_view.buffer<uint8>() + tensor_buffer_offset / sizeof(uint8));
|
||||
break;
|
||||
default:
|
||||
return InvalidArgumentError(
|
||||
|
@ -153,9 +171,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||
RET_CHECK_EQ(output_shape.dims[0], 1)
|
||||
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
||||
"converter.";
|
||||
RET_CHECK_GE(output_shape.dims[0], 1)
|
||||
<< "The batch dimension needs to be equal or larger than 1.";
|
||||
RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1)
|
||||
<< "Wrong output channel: " << output_shape.dims[3];
|
||||
return absl::OkStatus();
|
||||
|
|
Loading…
Reference in New Issue
Block a user