150 lines
5.9 KiB
C++
150 lines
5.9 KiB
C++
// Copyright 2021 The MediaPipe Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "absl/strings/substitute.h"
|
|
#include "mediapipe/framework/api2/node.h"
|
|
#include "mediapipe/framework/calculator_framework.h"
|
|
#include "mediapipe/framework/formats/classification.pb.h"
|
|
#include "mediapipe/framework/formats/detection.pb.h"
|
|
#include "mediapipe/framework/port/ret_check.h"
|
|
#include "mediapipe/framework/port/status.h"
|
|
#include "mediapipe/framework/port/statusor.h"
|
|
|
|
namespace mediapipe {
|
|
namespace api2 {
|
|
|
|
// No file-local helpers are defined here; the anonymous namespace is empty.
namespace {} // namespace
|
|
|
|
// Replaces the classification labels and scores from the input `Detection` with
|
|
// the ones provided in the input `ClassificationList`. Namely:
|
|
// * `label_id[i]` becomes `classification[i].index`
|
|
// * `score[i]` becomes `classification[i].score`
|
|
// * `label[i]` becomes `classification[i].label` (if present)
// * `display_name[i]` becomes `classification[i].display_name` (if present)
|
|
//
|
|
// In case the input `ClassificationList` contains no results (i.e.
|
|
// `classification` is empty, which may happen if the classifier uses a score
|
|
// threshold and no confident enough results were returned), the input
|
|
// `Detection` is returned unchanged.
|
|
//
|
|
// This is specifically designed for two-stage detection cascades where the
|
|
// detections returned by a standalone detector (typically a class-agnostic
|
|
// localizer) are fed e.g. into a `TfLiteTaskImageClassifierCalculator` through
|
|
// the optional "RECT" or "NORM_RECT" input, for example:
|
|
//
|
|
// node {
|
|
// calculator: "DetectionsToRectsCalculator"
|
|
// # Output of an upstream object detector.
|
|
// input_stream: "DETECTION:detection"
|
|
// output_stream: "NORM_RECT:norm_rect"
|
|
// }
|
|
// node {
|
|
// calculator: "TfLiteTaskImageClassifierCalculator"
|
|
// input_stream: "IMAGE:image"
|
|
// input_stream: "NORM_RECT:norm_rect"
|
|
// output_stream: "CLASSIFICATION_RESULT:classification_result"
|
|
// }
|
|
// node {
|
|
// calculator: "TfLiteTaskClassificationResultToClassificationsCalculator"
|
|
// input_stream: "CLASSIFICATION_RESULT:classification_result"
|
|
// output_stream: "CLASSIFICATION_LIST:classification_list"
|
|
// }
|
|
// node {
|
|
// calculator: "DetectionClassificationsMergerCalculator"
|
|
// input_stream: "INPUT_DETECTION:detection"
|
|
// input_stream: "CLASSIFICATION_LIST:classification_list"
|
|
// # Final output.
|
|
// output_stream: "OUTPUT_DETECTION:classified_detection"
|
|
// }
|
|
//
|
|
// Inputs:
|
|
// INPUT_DETECTION: `Detection` proto.
|
|
// CLASSIFICATION_LIST: `ClassificationList` proto.
|
|
//
|
|
// Output:
|
|
// OUTPUT_DETECTION: modified `Detection` proto.
|
|
class DetectionClassificationsMergerCalculator : public Node {
|
|
public:
|
|
static constexpr Input<Detection> kInputDetection{"INPUT_DETECTION"};
|
|
static constexpr Input<ClassificationList> kClassificationList{
|
|
"CLASSIFICATION_LIST"};
|
|
static constexpr Output<Detection> kOutputDetection{"OUTPUT_DETECTION"};
|
|
|
|
MEDIAPIPE_NODE_CONTRACT(kInputDetection, kClassificationList,
|
|
kOutputDetection);
|
|
|
|
absl::Status Process(CalculatorContext* cc) override;
|
|
};
|
|
// Registers the node through the api2 registration macro.
// NOTE(review): presumably this is what makes the calculator name usable in
// graph configs (as in the example in the class comment) — confirm against
// mediapipe/framework/api2/node.h.
MEDIAPIPE_REGISTER_NODE(DetectionClassificationsMergerCalculator);
|
|
|
|
// Builds the output detection from the current INPUT_DETECTION and
// CLASSIFICATION_LIST packets and sends it on OUTPUT_DETECTION.
//
// Behavior:
//  * Both inputs empty: returns OK without sending anything.
//  * Exactly one input empty: the corresponding RET_CHECK fails.
//  * Classification list has no entries: the input detection is sent as-is.
//  * Otherwise: label_id / score / label / display_name of the detection are
//    cleared and refilled from the classification entries, in order.
//
// Returns InvalidArgumentError when an entry lacks `index` or `score`, or when
// `label` (resp. `display_name`) is set on some entries but not on all.
absl::Status DetectionClassificationsMergerCalculator::Process(
    CalculatorContext* cc) {
  // Nothing to do when neither input carries a packet.
  const bool has_detection = !kInputDetection(cc).IsEmpty();
  const bool has_classifications = !kClassificationList(cc).IsEmpty();
  if (!has_detection && !has_classifications) {
    return absl::OkStatus();
  }
  // Past this point both inputs are mandatory.
  RET_CHECK(!kInputDetection(cc).IsEmpty());
  RET_CHECK(!kClassificationList(cc).IsEmpty());

  // Work on a copy of the input detection; the classification list is only
  // read.
  Detection merged = *kInputDetection(cc);
  const ClassificationList& results = *kClassificationList(cc);
  const int num_results = results.classification_size();

  // The classifier returned nothing: forward the detection untouched.
  if (num_results == 0) {
    kOutputDetection(cc).Send(merged);
    return absl::OkStatus();
  }

  // Drop the detector-provided values before refilling from the classifier.
  merged.clear_label_id();
  merged.clear_score();
  merged.clear_label();
  merged.clear_display_name();

  for (int i = 0; i < num_results; ++i) {
    const auto& entry = results.classification(i);
    // `index` and `score` are mandatory; `index` is checked first so that the
    // reported error matches the first missing field.
    if (!entry.has_index()) {
      return absl::InvalidArgumentError(
          "Missing required 'index' field in Classification proto.");
    }
    if (!entry.has_score()) {
      return absl::InvalidArgumentError(
          "Missing required 'score' field in Classification proto.");
    }
    merged.add_label_id(entry.index());
    merged.add_score(entry.score());
    // `label` and `display_name` are optional and copied only when set.
    if (entry.has_label()) {
      merged.add_label(entry.label());
    }
    if (entry.has_display_name()) {
      merged.add_display_name(entry.display_name());
    }
  }

  // Optional fields must be all-or-nothing, otherwise they would no longer
  // line up index-wise with `label_id`.
  const int num_ids = merged.label_id_size();
  const int num_labels = merged.label_size();
  if (num_labels != 0 && num_labels != num_ids) {
    return absl::InvalidArgumentError(absl::Substitute(
        "Each input Classification is expected to either always or never "
        "provide a 'label' field. Found $0 'label' fields for $1 "
        "'Classification' objects.",
        /*$0=*/num_labels, /*$1=*/num_ids));
  }
  const int num_display_names = merged.display_name_size();
  if (num_display_names != 0 && num_display_names != num_ids) {
    return absl::InvalidArgumentError(absl::Substitute(
        "Each input Classification is expected to either always or never "
        "provide a 'display_name' field. Found $0 'display_name' fields for "
        "$1 'Classification' objects.",
        /*$0=*/num_display_names, /*$1=*/num_ids));
  }

  kOutputDetection(cc).Send(merged);
  return absl::OkStatus();
}
|
|
|
|
} // namespace api2
|
|
} // namespace mediapipe
|