331 lines
13 KiB
C++
331 lines
13 KiB
C++
// Copyright 2022 The MediaPipe Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include <memory>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "absl/status/status.h"
|
|
#include "absl/status/statusor.h"
|
|
#include "mediapipe/framework/api2/node.h"
|
|
#include "mediapipe/framework/api2/packet.h"
|
|
#include "mediapipe/framework/api2/port.h"
|
|
#include "mediapipe/framework/calculator_framework.h"
|
|
#include "mediapipe/framework/formats/detection.pb.h"
|
|
#include "mediapipe/framework/formats/location_data.pb.h"
|
|
|
|
namespace mediapipe {
|
|
namespace api2 {
|
|
namespace {
|
|
|
|
// Clamps `value` to the range [0, upper_bound].
//
// Returns 0 when `value` is negative, zero, or (for floating-point T) NaN,
// and also when the capped result would be negative because `upper_bound`
// itself is negative. Otherwise returns the smaller of `value` and
// `upper_bound`.
template <typename T>
T BoundedValue(T value, T upper_bound) {
  // Deliberately phrased as a negated comparison: NaN compares false against
  // everything, so it takes this branch instead of leaking out unclamped.
  if (!(value > T{0})) {
    return T{0};
  }
  // `value` is strictly positive here; cap it from above.
  const T capped = upper_bound < value ? upper_bound : value;
  // A negative `upper_bound` would otherwise produce a negative result.
  return capped < T{0} ? T{0} : capped;
}
|
|
|
|
// Rewrites `detection`'s location data in place from a relative bounding box
// to a pixel bounding box.
//
// `image_size` is read as (width, height). xmin and width are scaled by the
// image width and clamped to [0, width]; ymin and height are scaled by the
// image height and clamped to [0, height]. Each field is clamped on its own,
// so xmin + width may still exceed the image width. Afterwards the format is
// set to BOUNDING_BOX and the relative bounding box is cleared.
//
// Always returns OkStatus.
absl::Status ConvertRelativeBoundingBoxToBoundingBox(
    const std::pair<int, int>& image_size, Detection* detection) {
  // Scales `relative` by `extent` and clamps the result to [0, extent],
  // truncating toward zero. The clamp happens BEFORE the float->int
  // conversion: converting a NaN or a value outside the range of int is
  // undefined behaviour, and relative coordinates are not validated upstream
  // of this function. NaN and non-positive products map to 0.
  const auto to_pixels = [](float relative, int extent) -> int {
    // The product is formed in float (as before); it is widened to double
    // only so the comparison against `extent` is exact for every int.
    const double scaled = static_cast<double>(relative * extent);
    if (extent <= 0 || !(scaled > 0.0)) {
      return 0;
    }
    if (scaled >= static_cast<double>(extent)) {
      return extent;
    }
    return static_cast<int>(scaled);
  };

  const int image_width = image_size.first;
  const int image_height = image_size.second;
  const auto& relative_bbox =
      detection->location_data().relative_bounding_box();
  auto* bbox = detection->mutable_location_data()->mutable_bounding_box();
  bbox->set_xmin(to_pixels(relative_bbox.xmin(), image_width));
  bbox->set_ymin(to_pixels(relative_bbox.ymin(), image_height));
  bbox->set_width(to_pixels(relative_bbox.width(), image_width));
  bbox->set_height(to_pixels(relative_bbox.height(), image_height));
  detection->mutable_location_data()->set_format(LocationData::BOUNDING_BOX);
  detection->mutable_location_data()->clear_relative_bounding_box();
  return absl::OkStatus();
}
|
|
|
|
// Rewrites `detection`'s location data in place from a pixel bounding box to
// a relative bounding box.
//
// `image_size` is read as (width, height). xmin and width are divided by the
// image width, ymin and height by the image height, and every quotient is
// passed through BoundedValue with an upper bound of 1. Afterwards the pixel
// bounding box is cleared and the format is set to RELATIVE_BOUNDING_BOX.
//
// Always returns OkStatus.
// NOTE(review): a zero image dimension makes the divisions below produce
// inf/NaN rather than an error -- confirm callers always pass positive sizes.
absl::Status ConvertBoundingBoxToRelativeBoundingBox(
    const std::pair<int, int>& image_size, Detection* detection) {
  const int width_px = image_size.first;
  const int height_px = image_size.second;

  // Compute all four normalized values before touching the mutable message.
  const auto& pixel_box = detection->location_data().bounding_box();
  const float rel_xmin = BoundedValue<float>(
      static_cast<float>(pixel_box.xmin()) / width_px, 1.0f);
  const float rel_ymin = BoundedValue<float>(
      static_cast<float>(pixel_box.ymin()) / height_px, 1.0f);
  const float rel_width = BoundedValue<float>(
      static_cast<float>(pixel_box.width()) / width_px, 1.0f);
  const float rel_height = BoundedValue<float>(
      static_cast<float>(pixel_box.height()) / height_px, 1.0f);

  auto* location_data = detection->mutable_location_data();
  auto* relative_box = location_data->mutable_relative_bounding_box();
  relative_box->set_xmin(rel_xmin);
  relative_box->set_ymin(rel_ymin);
  relative_box->set_width(rel_width);
  relative_box->set_height(rel_height);

  location_data->clear_bounding_box();
  location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
  return absl::OkStatus();
}
|
|
|
|
// Returns the location data format of a single detection.
//
// Fails with InvalidArgumentError when the detection carries no location
// data, and fails the RET_CHECK below when the format is anything other than
// RELATIVE_BOUNDING_BOX or BOUNDING_BOX.
absl::StatusOr<LocationData::Format> GetLocationDataFormat(
    const Detection& detection) {
  const bool has_location = detection.has_location_data();
  if (!has_location) {
    return absl::InvalidArgumentError("Detection must have location data.");
  }

  const LocationData::Format format = detection.location_data().format();
  // Only the two bounding-box formats can be converted by this calculator.
  RET_CHECK(format == LocationData::RELATIVE_BOUNDING_BOX ||
            format == LocationData::BOUNDING_BOX)
      << "Detection's location data format must be either "
         "RELATIVE_BOUNDING_BOX or BOUNDING_BOX";
  return format;
}
|
|
|
|
// Returns the location data format shared by every detection in
// `detections`.
//
// Fails the RET_CHECK when `detections` is empty, propagates any error from
// the single-detection overload, and returns InvalidArgumentError when two
// detections disagree on their format.
//
// The vector is only read, so it is taken by const reference; existing
// callers passing a non-const lvalue keep working.
absl::StatusOr<LocationData::Format> GetLocationDataFormat(
    const std::vector<Detection>& detections) {
  RET_CHECK(!detections.empty());
  ASSIGN_OR_RETURN(const LocationData::Format output_format,
                   GetLocationDataFormat(detections.front()));
  // The first element defines the expected format; compare the rest to it.
  // Iterators avoid the signed/unsigned comparison of an `int` index against
  // `size()`.
  for (auto it = detections.begin() + 1; it != detections.end(); ++it) {
    ASSIGN_OR_RETURN(const LocationData::Format format,
                     GetLocationDataFormat(*it));
    if (output_format != format) {
      return absl::InvalidArgumentError(
          "Input detections have different location data formats.");
    }
  }
  return output_format;
}
|
|
|
|
// Flips `detection`'s bounding box representation in place: a relative
// bounding box becomes a pixel bounding box and vice versa, using
// `image_size` as the scale.
//
// Returns InvalidArgumentError when the detection has no location data or
// when its format is neither RELATIVE_BOUNDING_BOX nor BOUNDING_BOX;
// otherwise returns the status of the selected conversion.
absl::Status ConvertBoundingBox(const std::pair<int, int>& image_size,
                                Detection* detection) {
  if (!detection->has_location_data()) {
    return absl::InvalidArgumentError("Detection must have location data.");
  }

  const LocationData::Format current_format =
      detection->location_data().format();
  if (current_format == LocationData::RELATIVE_BOUNDING_BOX) {
    return ConvertRelativeBoundingBoxToBoundingBox(image_size, detection);
  }
  if (current_format == LocationData::BOUNDING_BOX) {
    return ConvertBoundingBoxToRelativeBoundingBox(image_size, detection);
  }
  return absl::InvalidArgumentError(
      "Detection's location data format must be either "
      "RELATIVE_BOUNDING_BOX or BOUNDING_BOX.");
}
|
|
|
|
} // namespace
|
|
|
|
// Transforms relative bounding box(es) to pixel bounding box(es) in a detection
|
|
// proto/detection list/detection vector, or vice versa.
|
|
//
|
|
// Inputs:
|
|
// One of the following:
|
|
// DETECTION: A Detection proto.
|
|
// DETECTIONS: A std::vector<Detection> or a DetectionList proto.
|
|
// IMAGE_SIZE: A std::pair<int, int> representing image width and height.
|
|
//
|
|
// Outputs:
|
|
// At least one of the following:
|
|
// PIXEL_DETECTION: A Detection proto with pixel bounding box.
|
|
// PIXEL_DETECTIONS: An std::vector<Detection> with pixel bounding boxes.
|
|
// PIXEL_DETECTION_LIST: A DetectionList proto with pixel bounding boxes.
|
|
// RELATIVE_DETECTION: A Detection proto with relative bounding box.
|
|
// RELATIVE_DETECTIONS: An std::vector<Detection> with relative bounding boxes.
|
|
// RELATIVE_DETECTION_LIST: A DetectionList proto with relative bounding boxes.
|
|
//
|
|
// Example config:
|
|
// For input detection(s) with relative bounding box(es):
|
|
// node {
|
|
// calculator: "DetectionTransformationCalculator"
|
|
// input_stream: "DETECTION:input_detection"
|
|
// input_stream: "IMAGE_SIZE:image_size"
|
|
// output_stream: "PIXEL_DETECTION:output_detection"
|
|
// output_stream: "PIXEL_DETECTIONS:output_detections"
|
|
// output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
|
|
// }
|
|
//
|
|
// For input detection(s) with pixel bounding box(es):
|
|
// node {
|
|
// calculator: "DetectionTransformationCalculator"
|
|
// input_stream: "DETECTION:input_detection"
|
|
// input_stream: "IMAGE_SIZE:image_size"
|
|
// output_stream: "RELATIVE_DETECTION:output_detection"
|
|
// output_stream: "RELATIVE_DETECTIONS:output_detections"
|
|
// output_stream: "RELATIVE_DETECTION_LIST:output_detection_list"
|
|
// }
|
|
class DetectionTransformationCalculator : public Node {
|
|
public:
|
|
static constexpr Input<Detection>::Optional kInDetection{"DETECTION"};
|
|
static constexpr Input<OneOf<DetectionList, std::vector<Detection>>>::Optional
|
|
kInDetections{"DETECTIONS"};
|
|
static constexpr Input<std::pair<int, int>> kInImageSize{"IMAGE_SIZE"};
|
|
static constexpr Output<Detection>::Optional kOutPixelDetection{
|
|
"PIXEL_DETECTION"};
|
|
static constexpr Output<std::vector<Detection>>::Optional kOutPixelDetections{
|
|
"PIXEL_DETECTIONS"};
|
|
static constexpr Output<DetectionList>::Optional kOutPixelDetectionList{
|
|
"PIXEL_DETECTION_LIST"};
|
|
static constexpr Output<Detection>::Optional kOutRelativeDetection{
|
|
"RELATIVE_DETECTION"};
|
|
static constexpr Output<std::vector<Detection>>::Optional
|
|
kOutRelativeDetections{"RELATIVE_DETECTIONS"};
|
|
static constexpr Output<DetectionList>::Optional kOutRelativeDetectionList{
|
|
"RELATIVE_DETECTION_LIST"};
|
|
|
|
MEDIAPIPE_NODE_CONTRACT(kInDetection, kInDetections, kInImageSize,
|
|
kOutPixelDetection, kOutPixelDetections,
|
|
kOutPixelDetectionList, kOutRelativeDetection,
|
|
kOutRelativeDetections, kOutRelativeDetectionList);
|
|
|
|
static absl::Status UpdateContract(CalculatorContract* cc) {
|
|
RET_CHECK(kInImageSize(cc).IsConnected()) << "Image size must be provided.";
|
|
RET_CHECK(kInDetections(cc).IsConnected() ^ kInDetection(cc).IsConnected());
|
|
if (kInDetections(cc).IsConnected()) {
|
|
RET_CHECK(kOutPixelDetections(cc).IsConnected() ||
|
|
kOutPixelDetectionList(cc).IsConnected() ||
|
|
kOutRelativeDetections(cc).IsConnected() ||
|
|
kOutRelativeDetectionList(cc).IsConnected())
|
|
<< "Output must be a container of detections.";
|
|
}
|
|
RET_CHECK(kOutPixelDetections(cc).IsConnected() ||
|
|
kOutPixelDetectionList(cc).IsConnected() ||
|
|
kOutPixelDetection(cc).IsConnected() ||
|
|
kOutRelativeDetections(cc).IsConnected() ||
|
|
kOutRelativeDetectionList(cc).IsConnected() ||
|
|
kOutRelativeDetection(cc).IsConnected())
|
|
<< "Must connect at least one output stream.";
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
absl::Status Open(CalculatorContext* cc) override {
|
|
output_pixel_bounding_boxes_ = kOutPixelDetections(cc).IsConnected() ||
|
|
kOutPixelDetectionList(cc).IsConnected() ||
|
|
kOutPixelDetection(cc).IsConnected();
|
|
output_relative_bounding_boxes_ =
|
|
kOutRelativeDetections(cc).IsConnected() ||
|
|
kOutRelativeDetectionList(cc).IsConnected() ||
|
|
kOutRelativeDetection(cc).IsConnected();
|
|
RET_CHECK(output_pixel_bounding_boxes_ ^ output_relative_bounding_boxes_)
|
|
<< "All output streams must have the same stream tag prefix, either "
|
|
"\"PIXEL\" or \"RELATIVE_\".";
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
absl::Status Process(CalculatorContext* cc) override {
|
|
std::pair<int, int> image_size = kInImageSize(cc).Get();
|
|
std::vector<Detection> transformed_detections;
|
|
LocationData::Format input_location_data_format;
|
|
if (kInDetections(cc).IsEmpty() && kInDetection(cc).IsEmpty()) {
|
|
return absl::OkStatus();
|
|
}
|
|
if (kInDetections(cc).IsConnected()) {
|
|
transformed_detections = kInDetections(cc).Visit(
|
|
[&](const DetectionList& detection_list) {
|
|
return std::vector<Detection>(detection_list.detection().begin(),
|
|
detection_list.detection().end());
|
|
},
|
|
[&](const std::vector<Detection>& detection_vector) {
|
|
return detection_vector;
|
|
});
|
|
if (transformed_detections.empty()) {
|
|
OutputEmptyDetections(cc);
|
|
return absl::OkStatus();
|
|
}
|
|
ASSIGN_OR_RETURN(input_location_data_format,
|
|
GetLocationDataFormat(transformed_detections));
|
|
for (Detection& detection : transformed_detections) {
|
|
MP_RETURN_IF_ERROR(ConvertBoundingBox(image_size, &detection));
|
|
}
|
|
} else {
|
|
Detection transformed_detection(kInDetection(cc).Get());
|
|
if (!transformed_detection.has_location_data()) {
|
|
OutputEmptyDetections(cc);
|
|
return absl::OkStatus();
|
|
}
|
|
ASSIGN_OR_RETURN(input_location_data_format,
|
|
GetLocationDataFormat(kInDetection(cc).Get()));
|
|
MP_RETURN_IF_ERROR(
|
|
ConvertBoundingBox(image_size, &transformed_detection));
|
|
transformed_detections.push_back(transformed_detection);
|
|
}
|
|
if (input_location_data_format == LocationData::RELATIVE_BOUNDING_BOX) {
|
|
RET_CHECK(!output_relative_bounding_boxes_)
|
|
<< "Input detections are with relative bounding box(es), and the "
|
|
"output detections must have pixel bounding box(es).";
|
|
if (kOutPixelDetection(cc).IsConnected()) {
|
|
kOutPixelDetection(cc).Send(transformed_detections[0]);
|
|
}
|
|
if (kOutPixelDetections(cc).IsConnected()) {
|
|
kOutPixelDetections(cc).Send(transformed_detections);
|
|
}
|
|
if (kOutPixelDetectionList(cc).IsConnected()) {
|
|
DetectionList detection_list;
|
|
for (const auto& detection : transformed_detections) {
|
|
detection_list.add_detection()->CopyFrom(detection);
|
|
}
|
|
kOutPixelDetectionList(cc).Send(detection_list);
|
|
}
|
|
} else {
|
|
RET_CHECK(!output_pixel_bounding_boxes_)
|
|
<< "Input detections are with pixel bounding box(es), and the "
|
|
"output detections must have relative bounding box(es).";
|
|
if (kOutRelativeDetection(cc).IsConnected()) {
|
|
kOutRelativeDetection(cc).Send(transformed_detections[0]);
|
|
}
|
|
if (kOutRelativeDetections(cc).IsConnected()) {
|
|
kOutRelativeDetections(cc).Send(transformed_detections);
|
|
}
|
|
if (kOutRelativeDetectionList(cc).IsConnected()) {
|
|
DetectionList detection_list;
|
|
for (const auto& detection : transformed_detections) {
|
|
detection_list.add_detection()->CopyFrom(detection);
|
|
}
|
|
kOutRelativeDetectionList(cc).Send(detection_list);
|
|
}
|
|
}
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
private:
|
|
void OutputEmptyDetections(CalculatorContext* cc) {
|
|
if (kOutPixelDetection(cc).IsConnected()) {
|
|
kOutPixelDetection(cc).Send(Detection());
|
|
}
|
|
if (kOutPixelDetections(cc).IsConnected()) {
|
|
kOutPixelDetections(cc).Send(std::vector<Detection>());
|
|
}
|
|
if (kOutPixelDetectionList(cc).IsConnected()) {
|
|
kOutPixelDetectionList(cc).Send(DetectionList());
|
|
}
|
|
if (kOutRelativeDetection(cc).IsConnected()) {
|
|
kOutRelativeDetection(cc).Send(Detection());
|
|
}
|
|
if (kOutRelativeDetections(cc).IsConnected()) {
|
|
kOutRelativeDetections(cc).Send(std::vector<Detection>());
|
|
}
|
|
if (kOutRelativeDetectionList(cc).IsConnected()) {
|
|
kOutRelativeDetectionList(cc).Send(DetectionList());
|
|
}
|
|
}
|
|
|
|
bool output_relative_bounding_boxes_;
|
|
bool output_pixel_bounding_boxes_;
|
|
};
|
|
|
|
MEDIAPIPE_REGISTER_NODE(DetectionTransformationCalculator);
|
|
|
|
} // namespace api2
|
|
} // namespace mediapipe
|