Add GPU-to-CPU fallback for tensors_to_detections_calculator.
PiperOrigin-RevId: 544480883
commit 687075e5b8
parent 0ea54b1461
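The TensorsToDetectionsCalculator now keeps a gpu_has_enough_work_groups_ flag. When GpuInit() detects that num_classes exceeds the hardware work-group limit, it records that in the flag and returns absl::FailedPreconditionError instead of CHECK-failing; Process() logs the message and routes the tensors through ProcessCPU() instead. To make that possible, ProcessCPU() also accepts the 4-dim layouts {1, 1, num_boxes, num_coords} and {1, 1, num_boxes, num_classes} produced by GPU-oriented models, in addition to the usual 3-dim layouts.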
@@ -256,6 +256,7 @@ class TensorsToDetectionsCalculator : public Node {
   bool gpu_inited_ = false;
   bool gpu_input_ = false;
+  bool gpu_has_enough_work_groups_ = true;
   bool anchors_init_ = false;
 };

 MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator);
@@ -291,7 +292,7 @@ absl::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) {
 absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
   auto output_detections = absl::make_unique<std::vector<Detection>>();
   bool gpu_processing = false;
-  if (CanUseGpu()) {
+  if (CanUseGpu() && gpu_has_enough_work_groups_) {
     // Use GPU processing only if at least one input tensor is already on GPU
     // (to avoid CPU->GPU overhead).
     for (const auto& tensor : *kInTensors(cc)) {
@@ -321,11 +322,20 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
     RET_CHECK(!has_custom_box_indices_);
   }

-  if (gpu_processing) {
-    if (!gpu_inited_) {
-      MP_RETURN_IF_ERROR(GpuInit(cc));
+  if (gpu_processing && !gpu_inited_) {
+    auto status = GpuInit(cc);
+    if (status.ok()) {
       gpu_inited_ = true;
+    } else if (status.code() == absl::StatusCode::kFailedPrecondition) {
+      // For initialization errors due to hardware limitations, fall back
+      // to CPU processing.
+      LOG(WARNING) << status.message();
+    } else {
+      // For other errors, let the error propagate.
+      return status;
     }
+  }
+  if (gpu_processing && gpu_inited_) {
     MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get()));
   } else {
     MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
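Outside the diff context, the fallback control flow reads roughly like the sketch below. ProcessOnce, GpuInit, ProcessGPU, and ProcessCPU are simplified stand-ins (the real methods take a CalculatorContext and live on the calculator), so this only illustrates the status-code handling, not the calculator's actual API:

#include <iostream>

#include "absl/status/status.h"

// Simplified stand-ins for the calculator's methods (hypothetical names).
absl::Status GpuInit() {
  // Pretend GpuInit() detected a hardware limit.
  return absl::FailedPreconditionError(
      "Hardware limitation: Processing will be done on CPU, because "
      "num_classes 2048 exceeds the max work_group size 1024.");
}
absl::Status ProcessGPU() { return absl::OkStatus(); }
absl::Status ProcessCPU() { return absl::OkStatus(); }

absl::Status ProcessOnce(bool gpu_processing, bool& gpu_inited) {
  if (gpu_processing && !gpu_inited) {
    absl::Status status = GpuInit();
    if (status.ok()) {
      gpu_inited = true;
    } else if (status.code() == absl::StatusCode::kFailedPrecondition) {
      // Hardware limitation: warn once and fall through to the CPU path.
      std::cerr << "WARNING: " << status.message() << "\n";
    } else {
      return status;  // Any other failure still aborts processing.
    }
  }
  return (gpu_processing && gpu_inited) ? ProcessGPU() : ProcessCPU();
}

int main() {
  bool gpu_inited = false;
  absl::Status s = ProcessOnce(/*gpu_processing=*/true, gpu_inited);
  std::cout << (gpu_inited ? "GPU path: " : "CPU fallback: ") << s << "\n";
  return 0;
}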
@@ -346,17 +356,41 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
   // TODO: Add flexible input tensor size handling.
   auto raw_box_tensor =
       &input_tensors[tensor_mapping_.detections_tensor_index()];
-  RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3);
-  RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
   RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options";
-  RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
-  RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
+  if (raw_box_tensor->shape().dims.size() == 3) {
+    // The tensors from CPU inference have dim 3.
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
+  } else if (raw_box_tensor->shape().dims.size() == 4) {
+    // The tensors from GPU inference have dim 4. For gpu-cpu fallback support,
+    // we allow tensors with 4 dims.
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[1], 1);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_boxes_);
+    RET_CHECK_EQ(raw_box_tensor->shape().dims[3], num_coords_);
+  } else {
+    return absl::InvalidArgumentError(
+        "The dimensions of box Tensor must be 3 or 4.");
+  }
   auto raw_score_tensor =
       &input_tensors[tensor_mapping_.scores_tensor_index()];
-  RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3);
-  RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
-  RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
-  RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
+  if (raw_score_tensor->shape().dims.size() == 3) {
+    // The tensors from CPU inference have dim 3.
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_);
+  } else if (raw_score_tensor->shape().dims.size() == 4) {
+    // The tensors from GPU inference have dim 4. For gpu-cpu fallback support,
+    // we allow tensors with 4 dims.
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[1], 1);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_boxes_);
+    RET_CHECK_EQ(raw_score_tensor->shape().dims[3], num_classes_);
+  } else {
+    return absl::InvalidArgumentError(
+        "The dimensions of score Tensor must be 3 or 4.");
+  }
   auto raw_box_view = raw_box_tensor->GetCpuReadView();
   auto raw_boxes = raw_box_view.buffer<float>();
   auto raw_scores_view = raw_score_tensor->GetCpuReadView();
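The two branches accept the 3-dim {1, num_boxes, num_coords} layout from CPU inference and the 4-dim {1, 1, num_boxes, num_coords} layout from GPU-oriented models, so a graph that falls back to CPU keeps working with the same model. A hypothetical standalone helper with the same acceptance rule (not part of the calculator, and using plain status returns rather than RET_CHECK):

#include <vector>

#include "absl/status/status.h"

// Hypothetical helper mirroring the shape rule above: accept {1, rows, cols}
// or {1, 1, rows, cols}; reject any other rank or mismatched dimension.
absl::Status CheckDetectionShape(const std::vector<int>& dims, int rows,
                                 int cols) {
  if (dims.size() != 3 && dims.size() != 4) {
    return absl::InvalidArgumentError("Tensor must have 3 or 4 dims.");
  }
  const bool ok3 = dims.size() == 3 && dims[0] == 1 && dims[1] == rows &&
                   dims[2] == cols;
  const bool ok4 = dims.size() == 4 && dims[0] == 1 && dims[1] == 1 &&
                   dims[2] == rows && dims[3] == cols;
  return (ok3 || ok4) ? absl::OkStatus()
                      : absl::InvalidArgumentError("Unexpected tensor shape.");
}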
@@ -1111,8 +1145,13 @@ void main() {
   int max_wg_size;  //  typically <= 1024
   glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
                   &max_wg_size);  //  y-dim
-  CHECK_LT(num_classes_, max_wg_size)
-      << "# classes must be < " << max_wg_size;
+  gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
+  if (!gpu_has_enough_work_groups_) {
+    return absl::FailedPreconditionError(absl::StrFormat(
+        "Hardware limitation: Processing will be done on CPU, because "
+        "num_classes %d exceeds the max work_group size %d.",
+        num_classes_, max_wg_size));
+  }
   // TODO support better filtering.
   if (class_index_set_.is_allowlist) {
     CHECK_EQ(class_index_set_.values.size(),
@@ -1370,7 +1409,13 @@ kernel void scoreKernel(
         Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2});
     // # filter classes supported is hardware dependent.
     int max_wg_size = score_program_.maxTotalThreadsPerThreadgroup;
-    CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
+    gpu_has_enough_work_groups_ = num_classes_ < max_wg_size;
+    if (!gpu_has_enough_work_groups_) {
+      return absl::FailedPreconditionError(absl::StrFormat(
+          "Hardware limitation: Processing will be done on CPU, because "
+          "num_classes %d exceeds the max work_group size %d.",
+          num_classes_, max_wg_size));
+    }
   }

 #endif  // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE)
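With both the GL and Metal init paths recording gpu_has_enough_work_groups_ and returning absl::FailedPreconditionError instead of CHECK-failing, a num_classes that exceeds the hardware work-group limit no longer crashes the graph: the first Process() call logs the limitation and falls back to ProcessCPU(), and every later call skips the GPU branch up front via the CanUseGpu() && gpu_has_enough_work_groups_ check.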