Internal change
PiperOrigin-RevId: 490089940
This commit is contained in:
parent
adddf2c2ab
commit
d43d0ff615
|
@ -76,31 +76,49 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
||||||
return InvalidArgumentError(absl::StrCat(
|
return InvalidArgumentError(absl::StrCat(
|
||||||
"Unsupported format: ", static_cast<uint32_t>(input.image_format())));
|
"Unsupported format: ", static_cast<uint32_t>(input.image_format())));
|
||||||
}
|
}
|
||||||
// TODO: Remove the check once tensor_buffer_offset > 0 is
|
|
||||||
// supported.
|
RET_CHECK_GE(tensor_buffer_offset, 0)
|
||||||
RET_CHECK_EQ(tensor_buffer_offset, 0)
|
<< "The input tensor_buffer_offset needs to be non-negative.";
|
||||||
<< "The non-zero tensor_buffer_offset input is not supported yet.";
|
|
||||||
const auto& output_shape = output_tensor.shape();
|
const auto& output_shape = output_tensor.shape();
|
||||||
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
|
||||||
|
|
||||||
const int output_height = output_shape.dims[1];
|
const int output_height = output_shape.dims[1];
|
||||||
const int output_width = output_shape.dims[2];
|
const int output_width = output_shape.dims[2];
|
||||||
const int output_channels = output_shape.dims[3];
|
const int output_channels = output_shape.dims[3];
|
||||||
|
const int num_elements_per_img =
|
||||||
|
output_height * output_width * output_channels;
|
||||||
auto buffer_view = output_tensor.GetCpuWriteView();
|
auto buffer_view = output_tensor.GetCpuWriteView();
|
||||||
cv::Mat dst;
|
cv::Mat dst;
|
||||||
const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_;
|
const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_;
|
||||||
switch (tensor_type_) {
|
switch (tensor_type_) {
|
||||||
case Tensor::ElementType::kInt8:
|
case Tensor::ElementType::kInt8:
|
||||||
dst = cv::Mat(output_height, output_width, dst_data_type,
|
RET_CHECK_GE(output_shape.num_elements(),
|
||||||
buffer_view.buffer<int8>());
|
tensor_buffer_offset / sizeof(int8) + num_elements_per_img)
|
||||||
|
<< "The buffer offset + the input image size is larger than the "
|
||||||
|
"allocated tensor buffer.";
|
||||||
|
dst = cv::Mat(
|
||||||
|
output_height, output_width, dst_data_type,
|
||||||
|
buffer_view.buffer<int8>() + tensor_buffer_offset / sizeof(int8));
|
||||||
break;
|
break;
|
||||||
case Tensor::ElementType::kFloat32:
|
case Tensor::ElementType::kFloat32:
|
||||||
dst = cv::Mat(output_height, output_width, dst_data_type,
|
RET_CHECK_GE(
|
||||||
buffer_view.buffer<float>());
|
output_shape.num_elements(),
|
||||||
|
tensor_buffer_offset / sizeof(float) + num_elements_per_img)
|
||||||
|
<< "The buffer offset + the input image size is larger than the "
|
||||||
|
"allocated tensor buffer.";
|
||||||
|
dst = cv::Mat(
|
||||||
|
output_height, output_width, dst_data_type,
|
||||||
|
buffer_view.buffer<float>() + tensor_buffer_offset / sizeof(float));
|
||||||
break;
|
break;
|
||||||
case Tensor::ElementType::kUInt8:
|
case Tensor::ElementType::kUInt8:
|
||||||
dst = cv::Mat(output_height, output_width, dst_data_type,
|
RET_CHECK_GE(
|
||||||
buffer_view.buffer<uint8>());
|
output_shape.num_elements(),
|
||||||
|
tensor_buffer_offset / sizeof(uint8) + num_elements_per_img)
|
||||||
|
<< "The buffer offset + the input image size is larger than the "
|
||||||
|
"allocated tensor buffer.";
|
||||||
|
dst = cv::Mat(
|
||||||
|
output_height, output_width, dst_data_type,
|
||||||
|
buffer_view.buffer<uint8>() + tensor_buffer_offset / sizeof(uint8));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
return InvalidArgumentError(
|
return InvalidArgumentError(
|
||||||
|
@ -153,9 +171,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
||||||
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
|
||||||
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
RET_CHECK_EQ(output_shape.dims.size(), 4)
|
||||||
<< "Wrong output dims size: " << output_shape.dims.size();
|
<< "Wrong output dims size: " << output_shape.dims.size();
|
||||||
RET_CHECK_EQ(output_shape.dims[0], 1)
|
RET_CHECK_GE(output_shape.dims[0], 1)
|
||||||
<< "Handling batch dimension not equal to 1 is not implemented in this "
|
<< "The batch dimension needs to be equal or larger than 1.";
|
||||||
"converter.";
|
|
||||||
RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1)
|
RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1)
|
||||||
<< "Wrong output channel: " << output_shape.dims[3];
|
<< "Wrong output channel: " << output_shape.dims[3];
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
|
Loading…
Reference in New Issue
Block a user