260 lines
9.3 KiB
C++
260 lines
9.3 KiB
C++
// Copyright 2019 The MediaPipe Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#ifndef MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_
|
|
#define MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_
|
|
|
|
#include <list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "mediapipe/calculators/util/association_calculator.pb.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"
|
|
|
|
namespace mediapipe {
|
|
|
|
// Computes the overlap similarity based on Intersection over Union (IoU) of
|
|
// two rectangles.
|
|
inline float OverlapSimilarity(const Rectangle_f& rect1,
|
|
const Rectangle_f& rect2) {
|
|
if (!rect1.Intersects(rect2)) return 0.0f;
|
|
// Compute IoU similarity score.
|
|
const float intersection_area = Rectangle_f(rect1).Intersect(rect2).Area();
|
|
const float normalization = rect1.Area() + rect2.Area() - intersection_area;
|
|
return normalization > 0.0f ? intersection_area / normalization : 0.0f;
|
|
}
|
|
|
|
// AssociationCalculator<T> accepts multiple inputs of vectors of type T that can
|
|
// be converted to Rectangle_f. The output is a vector of type T that contains
|
|
// elements from the input vectors that don't overlap with each other. When
|
|
// two elements overlap, the element that comes in from a later input stream
|
|
// is kept in the output. This association operation is useful for multiple
|
|
// instance inference pipelines in MediaPipe.
|
|
// If an input stream is tagged with "PREV" tag, IDs of overlapping elements
|
|
// from "PREV" input stream are propagated to the output. Elements in the "PREV"
|
|
// input stream that don't overlap with other elements are not added to the
|
|
// output. This stream is designed to take detections from previous timestamp,
|
|
// e.g. output of PreviousLoopbackCalculator to provide temporal association.
|
|
// See AssociationDetectionCalculator and AssociationNormRectCalculator for
|
|
// example uses.
|
|
template <typename T>
|
|
class AssociationCalculator : public CalculatorBase {
|
|
public:
|
|
static absl::Status GetContract(CalculatorContract* cc) {
|
|
// Atmost one input stream can be tagged with "PREV".
|
|
RET_CHECK_LE(cc->Inputs().NumEntries("PREV"), 1);
|
|
|
|
if (cc->Inputs().HasTag("PREV")) {
|
|
RET_CHECK_GE(cc->Inputs().NumEntries(), 2);
|
|
}
|
|
|
|
for (CollectionItemId id = cc->Inputs().BeginId();
|
|
id < cc->Inputs().EndId(); ++id) {
|
|
cc->Inputs().Get(id).Set<std::vector<T>>();
|
|
}
|
|
|
|
cc->Outputs().Index(0).Set<std::vector<T>>();
|
|
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
absl::Status Open(CalculatorContext* cc) override {
|
|
cc->SetOffset(TimestampDiff(0));
|
|
|
|
has_prev_input_stream_ = cc->Inputs().HasTag("PREV");
|
|
if (has_prev_input_stream_) {
|
|
prev_input_stream_id_ = cc->Inputs().GetId("PREV", 0);
|
|
}
|
|
options_ = cc->Options<::mediapipe::AssociationCalculatorOptions>();
|
|
CHECK_GE(options_.min_similarity_threshold(), 0);
|
|
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
absl::Status Process(CalculatorContext* cc) override {
|
|
auto get_non_overlapping_elements = GetNonOverlappingElements(cc);
|
|
if (!get_non_overlapping_elements.ok()) {
|
|
return get_non_overlapping_elements.status();
|
|
}
|
|
std::list<T> result = get_non_overlapping_elements.value();
|
|
|
|
if (has_prev_input_stream_ &&
|
|
!cc->Inputs().Get(prev_input_stream_id_).IsEmpty()) {
|
|
// Processed all regular input streams. Now compare the result list
|
|
// elements with those in the PREV input stream, and propagate IDs from
|
|
// PREV input stream as appropriate.
|
|
const std::vector<T>& prev_input_vec =
|
|
cc->Inputs()
|
|
.Get(prev_input_stream_id_)
|
|
.template Get<std::vector<T>>();
|
|
|
|
MP_RETURN_IF_ERROR(
|
|
PropagateIdsFromPreviousToCurrent(prev_input_vec, &result));
|
|
}
|
|
|
|
auto output = absl::make_unique<std::vector<T>>();
|
|
for (auto it = result.begin(); it != result.end(); ++it) {
|
|
output->push_back(*it);
|
|
}
|
|
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
|
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
protected:
|
|
::mediapipe::AssociationCalculatorOptions options_;
|
|
|
|
bool has_prev_input_stream_;
|
|
CollectionItemId prev_input_stream_id_;
|
|
|
|
virtual absl::StatusOr<Rectangle_f> GetRectangle(const T& input) {
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
virtual std::pair<bool, int> GetId(const T& input) { return {false, -1}; }
|
|
|
|
virtual void SetId(T* input, int id) {}
|
|
|
|
private:
|
|
// Get a list of non-overlapping elements from all input streams, with
|
|
// increasing order of priority based on input stream index.
|
|
absl::StatusOr<std::list<T>> GetNonOverlappingElements(
|
|
CalculatorContext* cc) {
|
|
std::list<T> result;
|
|
|
|
// Initialize result with the first non-empty input vector.
|
|
CollectionItemId non_empty_id = cc->Inputs().BeginId();
|
|
for (CollectionItemId id = cc->Inputs().BeginId();
|
|
id < cc->Inputs().EndId(); ++id) {
|
|
if (id == prev_input_stream_id_ || cc->Inputs().Get(id).IsEmpty()) {
|
|
continue;
|
|
}
|
|
const std::vector<T>& input_vec =
|
|
cc->Inputs().Get(id).Get<std::vector<T>>();
|
|
if (!input_vec.empty()) {
|
|
non_empty_id = id;
|
|
result.push_back(input_vec[0]);
|
|
for (int j = 1; j < input_vec.size(); ++j) {
|
|
MP_RETURN_IF_ERROR(AddElementToList(input_vec[j], &result));
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Compare remaining input vectors with the non-empty result vector,
|
|
// remove lower-priority overlapping elements from the result vector and
|
|
// had corresponding higher-priority elements as necessary.
|
|
for (CollectionItemId id = non_empty_id + 1; id < cc->Inputs().EndId();
|
|
++id) {
|
|
if (id == prev_input_stream_id_ || cc->Inputs().Get(id).IsEmpty()) {
|
|
continue;
|
|
}
|
|
const std::vector<T>& input_vec =
|
|
cc->Inputs().Get(id).Get<std::vector<T>>();
|
|
|
|
for (int vi = 0; vi < input_vec.size(); ++vi) {
|
|
MP_RETURN_IF_ERROR(AddElementToList(input_vec[vi], &result));
|
|
}
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
absl::Status AddElementToList(T element, std::list<T>* current) {
|
|
// Compare this element with elements of the input collection. If this
|
|
// element has high overlap with elements of the collection, remove
|
|
// those elements from the collection and add this element.
|
|
ASSIGN_OR_RETURN(auto cur_rect, GetRectangle(element));
|
|
|
|
bool change_id = false;
|
|
int new_elem_id = -1;
|
|
|
|
for (auto uit = current->begin(); uit != current->end();) {
|
|
ASSIGN_OR_RETURN(auto prev_rect, GetRectangle(*uit));
|
|
if (OverlapSimilarity(cur_rect, prev_rect) >
|
|
options_.min_similarity_threshold()) {
|
|
std::pair<bool, int> prev_id = GetId(*uit);
|
|
// If prev_id.first is false when some element doesn't have an ID,
|
|
// change_id and new_elem_id will not be updated.
|
|
if (prev_id.first) {
|
|
change_id = prev_id.first;
|
|
new_elem_id = prev_id.second;
|
|
}
|
|
uit = current->erase(uit);
|
|
} else {
|
|
++uit;
|
|
}
|
|
}
|
|
|
|
if (change_id) {
|
|
SetId(&element, new_elem_id);
|
|
}
|
|
current->push_back(element);
|
|
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
// Compare elements of the current list with elements in from the collection
|
|
// of elements from the previous input stream, and propagate IDs from the
|
|
// previous input stream as appropriate.
|
|
absl::Status PropagateIdsFromPreviousToCurrent(
|
|
const std::vector<T>& prev_input_vec, std::list<T>* current) {
|
|
for (auto vit = current->begin(); vit != current->end(); ++vit) {
|
|
auto get_cur_rectangle = GetRectangle(*vit);
|
|
if (!get_cur_rectangle.ok()) {
|
|
return get_cur_rectangle.status();
|
|
}
|
|
const Rectangle_f& cur_rect = get_cur_rectangle.value();
|
|
|
|
bool change_id = false;
|
|
int id_for_vi = -1;
|
|
|
|
for (int ui = 0; ui < prev_input_vec.size(); ++ui) {
|
|
auto get_prev_rectangle = GetRectangle(prev_input_vec[ui]);
|
|
if (!get_prev_rectangle.ok()) {
|
|
return get_prev_rectangle.status();
|
|
}
|
|
const Rectangle_f& prev_rect = get_prev_rectangle.value();
|
|
|
|
if (OverlapSimilarity(cur_rect, prev_rect) >
|
|
options_.min_similarity_threshold()) {
|
|
std::pair<bool, int> prev_id = GetId(prev_input_vec[ui]);
|
|
// If prev_id.first is false when some element doesn't have an ID,
|
|
// change_id and id_for_vi will not be updated.
|
|
if (prev_id.first) {
|
|
change_id = prev_id.first;
|
|
id_for_vi = prev_id.second;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (change_id) {
|
|
T element = *vit;
|
|
SetId(&element, id_for_vi);
|
|
*vit = element;
|
|
}
|
|
}
|
|
return absl::OkStatus();
|
|
}
|
|
};
|
|
|
|
} // namespace mediapipe
|
|
|
|
#endif // MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_
|