Internal change

PiperOrigin-RevId: 490089940
This commit is contained in:
MediaPipe Team 2022-11-21 15:45:29 -08:00 committed by Copybara-Service
parent adddf2c2ab
commit d43d0ff615

View File

@ -76,31 +76,49 @@ class OpenCvProcessor : public ImageToTensorConverter {
return InvalidArgumentError(absl::StrCat( return InvalidArgumentError(absl::StrCat(
"Unsupported format: ", static_cast<uint32_t>(input.image_format()))); "Unsupported format: ", static_cast<uint32_t>(input.image_format())));
} }
// TODO: Remove the check once tensor_buffer_offset > 0 is
// supported. RET_CHECK_GE(tensor_buffer_offset, 0)
RET_CHECK_EQ(tensor_buffer_offset, 0) << "The input tensor_buffer_offset needs to be non-negative.";
<< "The non-zero tensor_buffer_offset input is not supported yet.";
const auto& output_shape = output_tensor.shape(); const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
const int output_height = output_shape.dims[1]; const int output_height = output_shape.dims[1];
const int output_width = output_shape.dims[2]; const int output_width = output_shape.dims[2];
const int output_channels = output_shape.dims[3]; const int output_channels = output_shape.dims[3];
const int num_elements_per_img =
output_height * output_width * output_channels;
auto buffer_view = output_tensor.GetCpuWriteView(); auto buffer_view = output_tensor.GetCpuWriteView();
cv::Mat dst; cv::Mat dst;
const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_; const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_;
switch (tensor_type_) { switch (tensor_type_) {
case Tensor::ElementType::kInt8: case Tensor::ElementType::kInt8:
dst = cv::Mat(output_height, output_width, dst_data_type, RET_CHECK_GE(output_shape.num_elements(),
buffer_view.buffer<int8>()); tensor_buffer_offset / sizeof(int8) + num_elements_per_img)
<< "The buffer offset + the input image size is larger than the "
"allocated tensor buffer.";
dst = cv::Mat(
output_height, output_width, dst_data_type,
buffer_view.buffer<int8>() + tensor_buffer_offset / sizeof(int8));
break; break;
case Tensor::ElementType::kFloat32: case Tensor::ElementType::kFloat32:
dst = cv::Mat(output_height, output_width, dst_data_type, RET_CHECK_GE(
buffer_view.buffer<float>()); output_shape.num_elements(),
tensor_buffer_offset / sizeof(float) + num_elements_per_img)
<< "The buffer offset + the input image size is larger than the "
"allocated tensor buffer.";
dst = cv::Mat(
output_height, output_width, dst_data_type,
buffer_view.buffer<float>() + tensor_buffer_offset / sizeof(float));
break; break;
case Tensor::ElementType::kUInt8: case Tensor::ElementType::kUInt8:
dst = cv::Mat(output_height, output_width, dst_data_type, RET_CHECK_GE(
buffer_view.buffer<uint8>()); output_shape.num_elements(),
tensor_buffer_offset / sizeof(uint8) + num_elements_per_img)
<< "The buffer offset + the input image size is larger than the "
"allocated tensor buffer.";
dst = cv::Mat(
output_height, output_width, dst_data_type,
buffer_view.buffer<uint8>() + tensor_buffer_offset / sizeof(uint8));
break; break;
default: default:
return InvalidArgumentError( return InvalidArgumentError(
@ -153,9 +171,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) { absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
RET_CHECK_EQ(output_shape.dims.size(), 4) RET_CHECK_EQ(output_shape.dims.size(), 4)
<< "Wrong output dims size: " << output_shape.dims.size(); << "Wrong output dims size: " << output_shape.dims.size();
RET_CHECK_EQ(output_shape.dims[0], 1) RET_CHECK_GE(output_shape.dims[0], 1)
<< "Handling batch dimension not equal to 1 is not implemented in this " << "The batch dimension needs to be equal or larger than 1.";
"converter.";
RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1) RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1)
<< "Wrong output channel: " << output_shape.dims[3]; << "Wrong output channel: " << output_shape.dims[3];
return absl::OkStatus(); return absl::OkStatus();