2019-06-17 01:03:25 +02:00
|
|
|
// Copyright 2019 The MediaPipe Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
#include "mediapipe/util/tensor_to_detection.h"

#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/types/variant.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/formats/location_opencv.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/core/framework/tensor_types.h"
|
|
|
|
|
|
|
|
namespace mediapipe {
|
|
|
|
|
|
|
|
using ::absl::StrFormat;
|
|
|
|
|
|
|
|
namespace tf = ::tensorflow;
|
|
|
|
|
|
|
|
// Builds a Detection proto for a single box.
//
// The box corners arrive in (ymin, xmin, ymax, xmax) order and are stored as
// a RELATIVE_BOUNDING_BOX described by (xmin, ymin, width, height). Width and
// height are plain corner differences: nothing is clamped or reordered, so an
// inverted box produces a negative extent. `score` becomes the detection's
// only score.
//
// `class_label` decides which label field is filled: an int is stored as a
// label_id, a string as a label. This follows detection.proto, which states
// "Either string or integer labels must be used but not both at the same
// time."
Detection TensorToDetection(
    float box_ymin, float box_xmin, float box_ymax, float box_xmax,
    const float score, const absl::variant<int, std::string>& class_label) {
  Detection detection;
  detection.add_score(score);

  if (const int* label_id = absl::get_if<int>(&class_label)) {
    detection.add_label_id(*label_id);
  } else {
    detection.add_label(absl::get<std::string>(class_label));
  }

  LocationData& location = *detection.mutable_location_data();
  location.set_format(LocationData::RELATIVE_BOUNDING_BOX);

  const float width = box_xmax - box_xmin;
  const float height = box_ymax - box_ymin;
  LocationData::RelativeBoundingBox& bbox =
      *location.mutable_relative_bounding_box();
  bbox.set_xmin(box_xmin);
  bbox.set_ymin(box_ymin);
  bbox.set_width(width);
  bbox.set_height(height);
  return detection;
}
|
|
|
|
|
|
|
|
// Convenience overload for models that output neither keypoints nor masks.
//
// It forwards to the full overload with two empty (0 x 0 x 0) float tensors
// standing in for keypoints and masks. Both have an empty first dimension,
// so no keypoints or masks are attached. The mask threshold is therefore
// never consulted and is passed as 0.
Status TensorsToDetections(const ::tensorflow::Tensor& num_detections,
                           const ::tensorflow::Tensor& boxes,
                           const ::tensorflow::Tensor& scores,
                           const ::tensorflow::Tensor& classes,
                           const std::map<int, std::string>& label_map,
                           std::vector<Detection>* detections) {
  const ::tensorflow::TensorShape empty_shape({0, 0, 0});
  const ::tensorflow::Tensor no_keypoints(::tensorflow::DT_FLOAT, empty_shape);
  const ::tensorflow::Tensor no_masks(::tensorflow::DT_FLOAT, empty_shape);
  constexpr float kUnusedMaskThreshold = 0.0f;
  return TensorsToDetections(num_detections, boxes, scores, classes,
                             no_keypoints, no_masks, kUnusedMaskThreshold,
                             label_map, detections);
}
|
|
|
|
|
|
|
|
// Converts the output tensors of an object-detection model into Detection
// protos, one per box, and appends them to `detections`.
//
// Arguments:
//   num_detections: Optional box count. When it has at least one dimension
//     and a non-empty first dimension, its single element is read as the
//     number of boxes: as int32 if its dtype is DT_INT32, otherwise as a
//     float truncated toward zero. Otherwise every row of `boxes` is used.
//   boxes: 2-D float tensor. Columns 0..3 of row i are passed to
//     TensorToDetection as (ymin, xmin, ymax, xmax).
//   scores: Float scores. Must be 1-D (one score per box) when `classes` is
//     non-empty. Otherwise it must be 2-D (box x class), and the
//     highest-scoring class of each box is selected.
//   classes: Float class ids, cast to int. A tensor with an empty first
//     dimension means the class is derived from `scores`.
//   keypoints: 3-D float tensor (box x keypoint x coordinate), where
//     coordinate 0 is y and coordinate 1 is x. Values are stored as relative
//     keypoints. An empty first dimension means "no keypoints".
//   masks: 3-D float tensor (box x height x width). An empty first dimension
//     means "no masks".
//   mask_threshold: Mask values that are not greater than this are set to 0.
//   label_map: If non-empty, class ids are mapped to string labels, and an id
//     missing from the map is an error. If empty, integer label ids are used.
//   detections: Output. New detections are appended after existing entries.
//
// Returns InvalidArgumentError when a tensor shape is inconsistent with the
// number of boxes, or when a class id has no label_map entry. Whenever an
// error status is returned, `detections` is left unmodified.
Status TensorsToDetections(const ::tensorflow::Tensor& num_detections,
                           const ::tensorflow::Tensor& boxes,
                           const ::tensorflow::Tensor& scores,
                           const ::tensorflow::Tensor& classes,
                           const ::tensorflow::Tensor& keypoints,
                           const ::tensorflow::Tensor& masks,
                           float mask_threshold,
                           const std::map<int, std::string>& label_map,
                           std::vector<Detection>* detections) {
  int num_boxes = -1;
  if (num_detections.dims() > 0 && num_detections.dim_size(0) != 0) {
    // The detection count may be exported either as int32 or as float.
    if (num_detections.dtype() == tf::DT_INT32) {
      num_boxes = num_detections.scalar<int32>()();
    } else {
      num_boxes = static_cast<int>(num_detections.scalar<float>()());
    }
    if (boxes.dim_size(0) < num_boxes) {
      return InvalidArgumentError(
          "First dimension of boxes tensor must be at least num_boxes");
    }
  } else {
    // If num_detections is not present, the number of boxes is determined by
    // the first dimension of the box tensor.
    if (boxes.dim_size(0) <= 0) {
      return InvalidArgumentError("Box tensor is empty");
    }
    num_boxes = boxes.dim_size(0);
  }

  // boxes_mat(i, 0..3) is read for every box. Reject other shapes instead of
  // reading out of bounds (or failing inside tensor<float, 2>()).
  if (boxes.dims() != 2 || (num_boxes > 0 && boxes.dim_size(1) < 4)) {
    return InvalidArgumentError(
        "Boxes tensor must be 2-D with at least 4 columns");
  }

  // classes_vec(i) is read for every i < num_boxes whenever classes is
  // non-empty, so this must hold regardless of how num_boxes was obtained.
  const bool has_classes = classes.dim_size(0) != 0;
  if (has_classes && classes.dim_size(0) < num_boxes) {
    return InvalidArgumentError(
        "First dimension of classes tensor must be at least num_boxes");
  }

  if (scores.dim_size(0) < num_boxes) {
    return InvalidArgumentError(
        "First dimension of scores tensor must be at least num_boxes");
  }

  const bool has_keypoints = keypoints.dim_size(0) != 0;
  if (has_keypoints && keypoints.dim_size(0) < num_boxes) {
    return InvalidArgumentError(
        "First dimension of keypoint tensors must be at least num_boxes");
  }
  // A keypoint tensor with no rows may still report a non-zero second
  // dimension (e.g. shape [0, K, 2]). Reading it would be out of bounds.
  const int num_keypoints = has_keypoints ? keypoints.dim_size(1) : 0;
  if (num_boxes > 0 && num_keypoints > 0 && keypoints.dim_size(2) < 2) {
    return InvalidArgumentError(
        "Last dimension of keypoint tensors must be at least 2 (y, x)");
  }

  const bool has_masks = masks.dim_size(0) != 0;
  if (has_masks && masks.dim_size(0) < num_boxes) {
    return InvalidArgumentError(
        "First dimension of the masks tensor should be at least num_boxes");
  }

  // The expected score layout depends only on whether classes are given, so
  // it is validated once instead of per box. As before, nothing is checked
  // when there are no boxes to convert.
  if (num_boxes > 0) {
    if (!has_classes && scores.dims() != 2) {
      return InvalidArgumentError(
          "Score tensor must have 2 dimensions where the last dimension has "
          "the scores for each class");
    }
    if (has_classes && scores.dims() != 1) {
      return InvalidArgumentError("Score tensor has more than 1 dimensions");
    }
  }

  // flat() views a (box x class) score matrix as one row-major vector.
  const auto& score_vec =
      scores.dims() > 1 ? scores.flat<float>() : scores.vec<float>();
  const auto& classes_vec = classes.vec<float>();
  const auto& boxes_mat = boxes.tensor<float, 2>();
  const auto& keypoints_mat = keypoints.tensor<float, 3>();
  const auto& masks_mat = masks.tensor<float, 3>();

  // Results are collected locally and published only once every box has
  // converted successfully, so an error never leaves partial output behind.
  std::vector<Detection> converted;
  if (num_boxes > 0) converted.reserve(num_boxes);

  for (int i = 0; i < num_boxes; ++i) {
    int class_id = -1;
    float score = -std::numeric_limits<float>::max();
    if (has_classes) {
      // Class and score tensors are both present; use them directly.
      score = score_vec(i);
      class_id = static_cast<int>(classes_vec(i));
    } else {
      // No class tensor: keep the top-scoring class of box i. Only a
      // strictly greater score wins, so ties resolve to the lowest index.
      const int num_classes = scores.dim_size(1);
      for (int c = 0; c < num_classes; ++c) {
        const float class_score = score_vec(i * num_classes + c);
        if (score < class_score) {
          score = class_score;
          class_id = c;
        }
      }
    }

    // Row i of `boxes` holds (ymin, xmin, ymax, xmax) for this box.
    Detection detection;
    if (label_map.empty()) {
      detection = TensorToDetection(boxes_mat(i, 0), boxes_mat(i, 1),
                                    boxes_mat(i, 2), boxes_mat(i, 3), score,
                                    class_id);
    } else {
      const auto label_it = label_map.find(class_id);
      if (label_it == label_map.end()) {
        return InvalidArgumentError(StrFormat(
            "Input label_map does not contain entry for integer label: %d",
            class_id));
      }
      detection = TensorToDetection(boxes_mat(i, 0), boxes_mat(i, 1),
                                    boxes_mat(i, 2), boxes_mat(i, 3), score,
                                    label_it->second);
    }

    // Adding keypoints (coordinate 0 is y, coordinate 1 is x).
    LocationData* location_data = detection.mutable_location_data();
    for (int j = 0; j < num_keypoints; ++j) {
      auto* keypoint = location_data->add_relative_keypoints();
      keypoint->set_y(keypoints_mat(i, j, 0));
      keypoint->set_x(keypoints_mat(i, j, 1));
    }

    // Adding masks: dimension 1 is the height and dimension 2 the width.
    if (has_masks) {
      const int mask_height = masks.dim_size(1);
      const int mask_width = masks.dim_size(2);
      cv::Mat mask_image(cv::Size(mask_width, mask_height), CV_32FC1);
      for (int h = 0; h < mask_height; ++h) {
        float* mask_row = mask_image.ptr<float>(h);
        for (int w = 0; w < mask_width; ++w) {
          const float value = masks_mat(i, h, w);
          // Values at or below the threshold are cleared.
          mask_row[w] = value > mask_threshold ? value : 0.0f;
        }
      }
      LocationData mask_location_data;
      mediapipe::CreateCvMaskLocation<float>(mask_image)
          .ConvertToProto(&mask_location_data);
      // Merge so the bounding box and keypoints set above are kept.
      location_data->MergeFrom(mask_location_data);
    }
    converted.push_back(std::move(detection));
  }

  detections->reserve(detections->size() + converted.size());
  for (Detection& detection : converted) {
    detections->push_back(std::move(detection));
  }
  return absl::OkStatus();
}
|
|
|
|
|
|
|
|
} // namespace mediapipe
|