150 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			150 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright 2021 The MediaPipe Authors.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //      http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
#include <utility>

#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
 | |
| 
 | |
| namespace mediapipe {
 | |
| namespace api2 {
 | |
| 
 | |
| namespace {}  // namespace
 | |
| 
 | |
| // Replaces the classification labels and scores from the input `Detection` with
 | |
| // the ones provided into the input `ClassificationList`. Namely:
 | |
| // * `label_id[i]` becomes `classification[i].index`
 | |
| // * `score[i]` becomes `classification[i].score`
 | |
| // * `label[i]` becomes `classification[i].label` (if present)
 | |
| //
 | |
| // In case the input `ClassificationList` contains no results (i.e.
 | |
| // `classification` is empty, which may happen if the classifier uses a score
 | |
| // threshold and no confident enough result were returned), the input
 | |
| // `Detection` is returned unchanged.
 | |
| //
 | |
| // This is specifically designed for two-stage detection cascades where the
 | |
| // detections returned by a standalone detector (typically a class-agnostic
 | |
| // localizer) are fed e.g. into a `TfLiteTaskImageClassifierCalculator` through
 | |
| // the optional "RECT" or "NORM_RECT" input, e.g:
 | |
| //
 | |
| // node {
 | |
| //   calculator: "DetectionsToRectsCalculator"
 | |
| //   # Output of an upstream object detector.
 | |
| //   input_stream: "DETECTION:detection"
 | |
| //   output_stream: "NORM_RECT:norm_rect"
 | |
| // }
 | |
| // node {
 | |
| //   calculator: "TfLiteTaskImageClassifierCalculator"
 | |
| //   input_stream: "IMAGE:image"
 | |
| //   input_stream: "NORM_RECT:norm_rect"
 | |
| //   output_stream: "CLASSIFICATION_RESULT:classification_result"
 | |
| // }
 | |
| // node {
 | |
| //   calculator: "TfLiteTaskClassificationResultToClassificationsCalculator"
 | |
| //   input_stream: "CLASSIFICATION_RESULT:classification_result"
 | |
| //   output_stream: "CLASSIFICATION_LIST:classification_list"
 | |
| // }
 | |
| // node {
 | |
| //   calculator: "DetectionClassificationsMergerCalculator"
 | |
| //   input_stream: "INPUT_DETECTION:detection"
 | |
| //   input_stream: "CLASSIFICATION_LIST:classification_list"
 | |
| //   # Final output.
 | |
| //   output_stream: "OUTPUT_DETECTION:classified_detection"
 | |
| // }
 | |
| //
 | |
| // Inputs:
 | |
| // INPUT_DETECTION: `Detection` proto.
 | |
| // CLASSIFICATION_LIST: `ClassificationList` proto.
 | |
| //
 | |
| // Output:
 | |
| // OUTPUT_DETECTION: modified `Detection` proto.
 | |
| class DetectionClassificationsMergerCalculator : public Node {
 | |
|  public:
 | |
|   static constexpr Input<Detection> kInputDetection{"INPUT_DETECTION"};
 | |
|   static constexpr Input<ClassificationList> kClassificationList{
 | |
|       "CLASSIFICATION_LIST"};
 | |
|   static constexpr Output<Detection> kOutputDetection{"OUTPUT_DETECTION"};
 | |
| 
 | |
|   MEDIAPIPE_NODE_CONTRACT(kInputDetection, kClassificationList,
 | |
|                           kOutputDetection);
 | |
| 
 | |
|   absl::Status Process(CalculatorContext* cc) override;
 | |
| };
 | |
| MEDIAPIPE_REGISTER_NODE(DetectionClassificationsMergerCalculator);
 | |
| 
 | |
| absl::Status DetectionClassificationsMergerCalculator::Process(
 | |
|     CalculatorContext* cc) {
 | |
|   if (kInputDetection(cc).IsEmpty() && kClassificationList(cc).IsEmpty()) {
 | |
|     return absl::OkStatus();
 | |
|   }
 | |
|   RET_CHECK(!kInputDetection(cc).IsEmpty());
 | |
|   RET_CHECK(!kClassificationList(cc).IsEmpty());
 | |
| 
 | |
|   Detection detection = *kInputDetection(cc);
 | |
|   const ClassificationList& classification_list = *kClassificationList(cc);
 | |
| 
 | |
|   // Update input detection only if classification did return results.
 | |
|   if (classification_list.classification_size() != 0) {
 | |
|     detection.clear_label_id();
 | |
|     detection.clear_score();
 | |
|     detection.clear_label();
 | |
|     detection.clear_display_name();
 | |
|     for (const auto& classification : classification_list.classification()) {
 | |
|       if (!classification.has_index()) {
 | |
|         return absl::InvalidArgumentError(
 | |
|             "Missing required 'index' field in Classification proto.");
 | |
|       }
 | |
|       detection.add_label_id(classification.index());
 | |
|       if (!classification.has_score()) {
 | |
|         return absl::InvalidArgumentError(
 | |
|             "Missing required 'score' field in Classification proto.");
 | |
|       }
 | |
|       detection.add_score(classification.score());
 | |
|       if (classification.has_label()) {
 | |
|         detection.add_label(classification.label());
 | |
|       }
 | |
|       if (classification.has_display_name()) {
 | |
|         detection.add_display_name(classification.display_name());
 | |
|       }
 | |
|     }
 | |
|     // Post-conversion sanity checks.
 | |
|     if (detection.label_size() != 0 &&
 | |
|         detection.label_size() != detection.label_id_size()) {
 | |
|       return absl::InvalidArgumentError(absl::Substitute(
 | |
|           "Each input Classification is expected to either always or never "
 | |
|           "provide a 'label' field. Found $0 'label' fields for $1 "
 | |
|           "'Classification' objects.",
 | |
|           /*$0=*/detection.label_size(), /*$1=*/detection.label_id_size()));
 | |
|     }
 | |
|     if (detection.display_name_size() != 0 &&
 | |
|         detection.display_name_size() != detection.label_id_size()) {
 | |
|       return absl::InvalidArgumentError(absl::Substitute(
 | |
|           "Each input Classification is expected to either always or never "
 | |
|           "provide a 'display_name' field. Found $0 'display_name' fields for "
 | |
|           "$1 'Classification' objects.",
 | |
|           /*$0=*/detection.display_name_size(),
 | |
|           /*$1=*/detection.label_id_size()));
 | |
|     }
 | |
|   }
 | |
|   kOutputDetection(cc).Send(detection);
 | |
|   return absl::OkStatus();
 | |
| }
 | |
| 
 | |
| }  // namespace api2
 | |
| }  // namespace mediapipe
 |