// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"

namespace mediapipe {
namespace api2 {

namespace {}  // namespace

// Replaces the classification labels and scores from the input `Detection` with
// the ones provided in the input `ClassificationList`. Namely:
// * `label_id[i]` becomes `classification[i].index`
// * `score[i]` becomes `classification[i].score`
// * `label[i]` becomes `classification[i].label` (if present)
// * `display_name[i]` becomes `classification[i].display_name` (if present)
//
// In case the input `ClassificationList` contains no results (i.e.
// `classification` is empty, which may happen if the classifier uses a score
// threshold and no sufficiently confident results were returned), the input
// `Detection` is returned unchanged.
//
// This is specifically designed for two-stage detection cascades where the
// detections returned by a standalone detector (typically a class-agnostic
// localizer) are fed into a classifier such as
// `TfLiteTaskImageClassifierCalculator` through the optional "RECT" or
// "NORM_RECT" input, e.g.:
//
// node {
//   calculator: "DetectionsToRectsCalculator"
//   # Output of an upstream object detector.
//   input_stream: "DETECTION:detection"
//   output_stream: "NORM_RECT:norm_rect"
// }
// node {
//   calculator: "TfLiteTaskImageClassifierCalculator"
//   input_stream: "IMAGE:image"
//   input_stream: "NORM_RECT:norm_rect"
//   output_stream: "CLASSIFICATION_RESULT:classification_result"
// }
// node {
//   calculator: "TfLiteTaskClassificationResultToClassificationsCalculator"
//   input_stream: "CLASSIFICATION_RESULT:classification_result"
//   output_stream: "CLASSIFICATION_LIST:classification_list"
// }
// node {
//   calculator: "DetectionClassificationsMergerCalculator"
//   input_stream: "INPUT_DETECTION:detection"
//   input_stream: "CLASSIFICATION_LIST:classification_list"
//   # Final output.
//   output_stream: "OUTPUT_DETECTION:classified_detection"
// }
//
// Inputs:
//   INPUT_DETECTION: `Detection` proto.
//   CLASSIFICATION_LIST: `ClassificationList` proto.
//
// Output:
//   OUTPUT_DETECTION: modified `Detection` proto.
class DetectionClassificationsMergerCalculator : public Node { public: static constexpr Input kInputDetection{"INPUT_DETECTION"}; static constexpr Input kClassificationList{ "CLASSIFICATION_LIST"}; static constexpr Output kOutputDetection{"OUTPUT_DETECTION"}; MEDIAPIPE_NODE_CONTRACT(kInputDetection, kClassificationList, kOutputDetection); absl::Status Process(CalculatorContext* cc) override; }; MEDIAPIPE_REGISTER_NODE(DetectionClassificationsMergerCalculator); absl::Status DetectionClassificationsMergerCalculator::Process( CalculatorContext* cc) { if (kInputDetection(cc).IsEmpty() && kClassificationList(cc).IsEmpty()) { return absl::OkStatus(); } RET_CHECK(!kInputDetection(cc).IsEmpty()); RET_CHECK(!kClassificationList(cc).IsEmpty()); Detection detection = *kInputDetection(cc); const ClassificationList& classification_list = *kClassificationList(cc); // Update input detection only if classification did return results. if (classification_list.classification_size() != 0) { detection.clear_label_id(); detection.clear_score(); detection.clear_label(); detection.clear_display_name(); for (const auto& classification : classification_list.classification()) { if (!classification.has_index()) { return absl::InvalidArgumentError( "Missing required 'index' field in Classification proto."); } detection.add_label_id(classification.index()); if (!classification.has_score()) { return absl::InvalidArgumentError( "Missing required 'score' field in Classification proto."); } detection.add_score(classification.score()); if (classification.has_label()) { detection.add_label(classification.label()); } if (classification.has_display_name()) { detection.add_display_name(classification.display_name()); } } // Post-conversion sanity checks. if (detection.label_size() != 0 && detection.label_size() != detection.label_id_size()) { return absl::InvalidArgumentError(absl::Substitute( "Each input Classification is expected to either always or never " "provide a 'label' field. 
Found $0 'label' fields for $1 " "'Classification' objects.", /*$0=*/detection.label_size(), /*$1=*/detection.label_id_size())); } if (detection.display_name_size() != 0 && detection.display_name_size() != detection.label_id_size()) { return absl::InvalidArgumentError(absl::Substitute( "Each input Classification is expected to either always or never " "provide a 'display_name' field. Found $0 'display_name' fields for " "$1 'Classification' objects.", /*$0=*/detection.display_name_size(), /*$1=*/detection.label_id_size())); } } kOutputDetection(cc).Send(detection); return absl::OkStatus(); } } // namespace api2 } // namespace mediapipe