// Copyright 2022 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/location_data.pb.h" namespace mediapipe { namespace api2 { namespace { template T BoundedValue(T value, T upper_bound) { T output = std::min(value, upper_bound); if (output < 0) { return 0; } return output; } absl::Status ConvertRelativeBoundingBoxToBoundingBox( const std::pair& image_size, Detection* detection) { const int image_width = image_size.first; const int image_height = image_size.second; const auto& relative_bbox = detection->location_data().relative_bounding_box(); auto* bbox = detection->mutable_location_data()->mutable_bounding_box(); bbox->set_xmin( BoundedValue(relative_bbox.xmin() * image_width, image_width)); bbox->set_ymin( BoundedValue(relative_bbox.ymin() * image_height, image_height)); bbox->set_width( BoundedValue(relative_bbox.width() * image_width, image_width)); bbox->set_height( BoundedValue(relative_bbox.height() * image_height, image_height)); detection->mutable_location_data()->set_format(LocationData::BOUNDING_BOX); detection->mutable_location_data()->clear_relative_bounding_box(); return absl::OkStatus(); } absl::Status ConvertBoundingBoxToRelativeBoundingBox( const std::pair& image_size, Detection* detection) { int image_width = image_size.first; int image_height = image_size.second; const auto& bbox = detection->location_data().bounding_box(); auto* relative_bbox = detection->mutable_location_data()->mutable_relative_bounding_box(); relative_bbox->set_xmin( BoundedValue((float)bbox.xmin() / image_width, 1.0f)); relative_bbox->set_ymin( BoundedValue((float)bbox.ymin() / image_height, 1.0f)); relative_bbox->set_width( BoundedValue((float)bbox.width() / image_width, 1.0f)); relative_bbox->set_height( BoundedValue((float)bbox.height() / image_height, 1.0f)); detection->mutable_location_data()->clear_bounding_box(); detection->mutable_location_data()->set_format( LocationData::RELATIVE_BOUNDING_BOX); return absl::OkStatus(); } absl::StatusOr GetLocationDataFormat( const Detection& detection) { if (!detection.has_location_data()) { return absl::InvalidArgumentError("Detection must have location data."); } LocationData::Format format = detection.location_data().format(); RET_CHECK(format == LocationData::RELATIVE_BOUNDING_BOX || format == LocationData::BOUNDING_BOX) << "Detection's location data format must be either " "RELATIVE_BOUNDING_BOX or BOUNDING_BOX"; return format; } absl::StatusOr GetLocationDataFormat( std::vector& detections) { RET_CHECK(!detections.empty()); LocationData::Format output_format; ASSIGN_OR_RETURN(output_format, GetLocationDataFormat(detections[0])); for (int i = 1; i < detections.size(); ++i) { ASSIGN_OR_RETURN(LocationData::Format format, GetLocationDataFormat(detections[i])); if (output_format != format) { return absl::InvalidArgumentError( "Input detections have different location data formats."); } } return output_format; } absl::Status ConvertBoundingBox(const std::pair& image_size, Detection* detection) { if (!detection->has_location_data()) { return absl::InvalidArgumentError("Detection must have location data."); } switch (detection->location_data().format()) { case LocationData::RELATIVE_BOUNDING_BOX: return ConvertRelativeBoundingBoxToBoundingBox(image_size, detection); case LocationData::BOUNDING_BOX: return ConvertBoundingBoxToRelativeBoundingBox(image_size, detection); default: return absl::InvalidArgumentError( "Detection's location data format must be either " "RELATIVE_BOUNDING_BOX or BOUNDING_BOX."); } } } // namespace // Transforms relative bounding box(es) to pixel bounding box(es) in a detection // proto/detection list/detection vector, or vice versa. // // Inputs: // One of the following: // DETECTION: A Detection proto. // DETECTIONS: An std::vector/ a DetectionList proto. // IMAGE_SIZE: A std::pair represention image width and height. // // Outputs: // At least one of the following: // PIXEL_DETECTION: A Detection proto with pixel bounding box. // PIXEL_DETECTIONS: An std::vector with pixel bounding boxes. // PIXEL_DETECTION_LIST: A DetectionList proto with pixel bounding boxes. // RELATIVE_DETECTION: A Detection proto with relative bounding box. // RELATIVE_DETECTIONS: An std::vector with relative bounding boxes. // RELATIVE_DETECTION_LIST: A DetectionList proto with relative bounding boxes. // // Example config: // For input detection(s) with relative bounding box(es): // node { // calculator: "DetectionTransformationCalculator" // input_stream: "DETECTION:input_detection" // input_stream: "IMAGE_SIZE:image_size" // output_stream: "PIXEL_DETECTION:output_detection" // output_stream: "PIXEL_DETECTIONS:output_detections" // output_stream: "PIXEL_DETECTION_LIST:output_detection_list" // } // // For input detection(s) with pixel bounding box(es): // node { // calculator: "DetectionTransformationCalculator" // input_stream: "DETECTION:input_detection" // input_stream: "IMAGE_SIZE:image_size" // output_stream: "RELATIVE_DETECTION:output_detection" // output_stream: "RELATIVE_DETECTIONS:output_detections" // output_stream: "RELATIVE_DETECTION_LIST:output_detection_list" // } class DetectionTransformationCalculator : public Node { public: static constexpr Input::Optional kInDetection{"DETECTION"}; static constexpr Input>>::Optional kInDetections{"DETECTIONS"}; static constexpr Input> kInImageSize{"IMAGE_SIZE"}; static constexpr Output::Optional kOutPixelDetection{ "PIXEL_DETECTION"}; static constexpr Output>::Optional kOutPixelDetections{ "PIXEL_DETECTIONS"}; static constexpr Output::Optional kOutPixelDetectionList{ "PIXEL_DETECTION_LIST"}; static constexpr Output::Optional kOutRelativeDetection{ "RELATIVE_DETECTION"}; static constexpr Output>::Optional kOutRelativeDetections{"RELATIVE_DETECTIONS"}; static constexpr Output::Optional kOutRelativeDetectionList{ "RELATIVE_DETECTION_LIST"}; MEDIAPIPE_NODE_CONTRACT(kInDetection, kInDetections, kInImageSize, kOutPixelDetection, kOutPixelDetections, kOutPixelDetectionList, kOutRelativeDetection, kOutRelativeDetections, kOutRelativeDetectionList); static absl::Status UpdateContract(CalculatorContract* cc) { RET_CHECK(kInImageSize(cc).IsConnected()) << "Image size must be provided."; RET_CHECK(kInDetections(cc).IsConnected() ^ kInDetection(cc).IsConnected()); if (kInDetections(cc).IsConnected()) { RET_CHECK(kOutPixelDetections(cc).IsConnected() || kOutPixelDetectionList(cc).IsConnected() || kOutRelativeDetections(cc).IsConnected() || kOutRelativeDetectionList(cc).IsConnected()) << "Output must be a container of detections."; } RET_CHECK(kOutPixelDetections(cc).IsConnected() || kOutPixelDetectionList(cc).IsConnected() || kOutPixelDetection(cc).IsConnected() || kOutRelativeDetections(cc).IsConnected() || kOutRelativeDetectionList(cc).IsConnected() || kOutRelativeDetection(cc).IsConnected()) << "Must connect at least one output stream."; return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { output_pixel_bounding_boxes_ = kOutPixelDetections(cc).IsConnected() || kOutPixelDetectionList(cc).IsConnected() || kOutPixelDetection(cc).IsConnected(); output_relative_bounding_boxes_ = kOutRelativeDetections(cc).IsConnected() || kOutRelativeDetectionList(cc).IsConnected() || kOutRelativeDetection(cc).IsConnected(); RET_CHECK(output_pixel_bounding_boxes_ ^ output_relative_bounding_boxes_) << "All output streams must have the same stream tag prefix, either " "\"PIXEL\" or \"RELATIVE_\"."; return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { std::pair image_size = kInImageSize(cc).Get(); std::vector transformed_detections; LocationData::Format input_location_data_format; if (kInDetections(cc).IsEmpty() && kInDetection(cc).IsEmpty()) { return absl::OkStatus(); } if (kInDetections(cc).IsConnected()) { transformed_detections = kInDetections(cc).Visit( [&](const DetectionList& detection_list) { return std::vector(detection_list.detection().begin(), detection_list.detection().end()); }, [&](const std::vector& detection_vector) { return detection_vector; }); if (transformed_detections.empty()) { OutputEmptyDetections(cc); return absl::OkStatus(); } ASSIGN_OR_RETURN(input_location_data_format, GetLocationDataFormat(transformed_detections)); for (Detection& detection : transformed_detections) { MP_RETURN_IF_ERROR(ConvertBoundingBox(image_size, &detection)); } } else { Detection transformed_detection(kInDetection(cc).Get()); if (!transformed_detection.has_location_data()) { OutputEmptyDetections(cc); return absl::OkStatus(); } ASSIGN_OR_RETURN(input_location_data_format, GetLocationDataFormat(kInDetection(cc).Get())); MP_RETURN_IF_ERROR( ConvertBoundingBox(image_size, &transformed_detection)); transformed_detections.push_back(transformed_detection); } if (input_location_data_format == LocationData::RELATIVE_BOUNDING_BOX) { RET_CHECK(!output_relative_bounding_boxes_) << "Input detections are with relative bounding box(es), and the " "output detections must have pixel bounding box(es)."; if (kOutPixelDetection(cc).IsConnected()) { kOutPixelDetection(cc).Send(transformed_detections[0]); } if (kOutPixelDetections(cc).IsConnected()) { kOutPixelDetections(cc).Send(transformed_detections); } if (kOutPixelDetectionList(cc).IsConnected()) { DetectionList detection_list; for (const auto& detection : transformed_detections) { detection_list.add_detection()->CopyFrom(detection); } kOutPixelDetectionList(cc).Send(detection_list); } } else { RET_CHECK(!output_pixel_bounding_boxes_) << "Input detections are with pixel bounding box(es), and the " "output detections must have relative bounding box(es)."; if (kOutRelativeDetection(cc).IsConnected()) { kOutRelativeDetection(cc).Send(transformed_detections[0]); } if (kOutRelativeDetections(cc).IsConnected()) { kOutRelativeDetections(cc).Send(transformed_detections); } if (kOutRelativeDetectionList(cc).IsConnected()) { DetectionList detection_list; for (const auto& detection : transformed_detections) { detection_list.add_detection()->CopyFrom(detection); } kOutRelativeDetectionList(cc).Send(detection_list); } } return absl::OkStatus(); } private: void OutputEmptyDetections(CalculatorContext* cc) { if (kOutPixelDetection(cc).IsConnected()) { kOutPixelDetection(cc).Send(Detection()); } if (kOutPixelDetections(cc).IsConnected()) { kOutPixelDetections(cc).Send(std::vector()); } if (kOutPixelDetectionList(cc).IsConnected()) { kOutPixelDetectionList(cc).Send(DetectionList()); } if (kOutRelativeDetection(cc).IsConnected()) { kOutRelativeDetection(cc).Send(Detection()); } if (kOutRelativeDetections(cc).IsConnected()) { kOutRelativeDetections(cc).Send(std::vector()); } if (kOutRelativeDetectionList(cc).IsConnected()) { kOutRelativeDetectionList(cc).Send(DetectionList()); } } bool output_relative_bounding_boxes_; bool output_pixel_bounding_boxes_; }; MEDIAPIPE_REGISTER_NODE(DetectionTransformationCalculator); } // namespace api2 } // namespace mediapipe