// Copyright 2019 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_ #define MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_ #include #include #include "absl/memory/memory.h" #include "mediapipe/calculators/util/association_calculator.pb.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/rectangle.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { // Computes the overlap similarity based on Intersection over Union (IoU) of // two rectangles. inline float OverlapSimilarity(const Rectangle_f& rect1, const Rectangle_f& rect2) { if (!rect1.Intersects(rect2)) return 0.0f; // Compute IoU similarity score. const float intersection_area = Rectangle_f(rect1).Intersect(rect2).Area(); const float normalization = rect1.Area() + rect2.Area() - intersection_area; return normalization > 0.0f ? intersection_area / normalization : 0.0f; } // AssocationCalculator accepts multiple inputs of vectors of type T that can // be converted to Rectangle_f. The output is a vector of type T that contains // elements from the input vectors that don't overlap with each other. When // two elements overlap, the element that comes in from a later input stream // is kept in the output. This association operation is useful for multiple // instance inference pipelines in MediaPipe. // If an input stream is tagged with "PREV" tag, IDs of overlapping elements // from "PREV" input stream are propagated to the output. Elements in the "PREV" // input stream that don't overlap with other elements are not added to the // output. This stream is designed to take detections from previous timestamp, // e.g. output of PreviousLoopbackCalculator to provide temporal association. // See AssociationDetectionCalculator and AssociationNormRectCalculator for // example uses. template class AssociationCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { // Atmost one input stream can be tagged with "PREV". RET_CHECK_LE(cc->Inputs().NumEntries("PREV"), 1); if (cc->Inputs().HasTag("PREV")) { RET_CHECK_GE(cc->Inputs().NumEntries(), 2); } for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { cc->Inputs().Get(id).Set>(); } cc->Outputs().Index(0).Set>(); return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); has_prev_input_stream_ = cc->Inputs().HasTag("PREV"); if (has_prev_input_stream_) { prev_input_stream_id_ = cc->Inputs().GetId("PREV", 0); } options_ = cc->Options<::mediapipe::AssociationCalculatorOptions>(); CHECK_GE(options_.min_similarity_threshold(), 0); return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { auto get_non_overlapping_elements = GetNonOverlappingElements(cc); if (!get_non_overlapping_elements.ok()) { return get_non_overlapping_elements.status(); } std::list result = get_non_overlapping_elements.value(); if (has_prev_input_stream_ && !cc->Inputs().Get(prev_input_stream_id_).IsEmpty()) { // Processed all regular input streams. Now compare the result list // elements with those in the PREV input stream, and propagate IDs from // PREV input stream as appropriate. const std::vector& prev_input_vec = cc->Inputs() .Get(prev_input_stream_id_) .template Get>(); MP_RETURN_IF_ERROR( PropagateIdsFromPreviousToCurrent(prev_input_vec, &result)); } auto output = absl::make_unique>(); for (auto it = result.begin(); it != result.end(); ++it) { output->push_back(*it); } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); return absl::OkStatus(); } protected: ::mediapipe::AssociationCalculatorOptions options_; bool has_prev_input_stream_; CollectionItemId prev_input_stream_id_; virtual absl::StatusOr GetRectangle(const T& input) { return absl::OkStatus(); } virtual std::pair GetId(const T& input) { return {false, -1}; } virtual void SetId(T* input, int id) {} private: // Get a list of non-overlapping elements from all input streams, with // increasing order of priority based on input stream index. absl::StatusOr> GetNonOverlappingElements( CalculatorContext* cc) { std::list result; // Initialize result with the first non-empty input vector. CollectionItemId non_empty_id = cc->Inputs().BeginId(); for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { if (id == prev_input_stream_id_ || cc->Inputs().Get(id).IsEmpty()) { continue; } const std::vector& input_vec = cc->Inputs().Get(id).Get>(); if (!input_vec.empty()) { non_empty_id = id; result.push_back(input_vec[0]); for (int j = 1; j < input_vec.size(); ++j) { MP_RETURN_IF_ERROR(AddElementToList(input_vec[j], &result)); } break; } } // Compare remaining input vectors with the non-empty result vector, // remove lower-priority overlapping elements from the result vector and // had corresponding higher-priority elements as necessary. for (CollectionItemId id = non_empty_id + 1; id < cc->Inputs().EndId(); ++id) { if (id == prev_input_stream_id_ || cc->Inputs().Get(id).IsEmpty()) { continue; } const std::vector& input_vec = cc->Inputs().Get(id).Get>(); for (int vi = 0; vi < input_vec.size(); ++vi) { MP_RETURN_IF_ERROR(AddElementToList(input_vec[vi], &result)); } } return result; } absl::Status AddElementToList(T element, std::list* current) { // Compare this element with elements of the input collection. If this // element has high overlap with elements of the collection, remove // those elements from the collection and add this element. ASSIGN_OR_RETURN(auto cur_rect, GetRectangle(element)); bool change_id = false; int new_elem_id = -1; for (auto uit = current->begin(); uit != current->end();) { ASSIGN_OR_RETURN(auto prev_rect, GetRectangle(*uit)); if (OverlapSimilarity(cur_rect, prev_rect) > options_.min_similarity_threshold()) { std::pair prev_id = GetId(*uit); // If prev_id.first is false when some element doesn't have an ID, // change_id and new_elem_id will not be updated. if (prev_id.first) { change_id = prev_id.first; new_elem_id = prev_id.second; } uit = current->erase(uit); } else { ++uit; } } if (change_id) { SetId(&element, new_elem_id); } current->push_back(element); return absl::OkStatus(); } // Compare elements of the current list with elements in from the collection // of elements from the previous input stream, and propagate IDs from the // previous input stream as appropriate. absl::Status PropagateIdsFromPreviousToCurrent( const std::vector& prev_input_vec, std::list* current) { for (auto vit = current->begin(); vit != current->end(); ++vit) { auto get_cur_rectangle = GetRectangle(*vit); if (!get_cur_rectangle.ok()) { return get_cur_rectangle.status(); } const Rectangle_f& cur_rect = get_cur_rectangle.value(); bool change_id = false; int id_for_vi = -1; for (int ui = 0; ui < prev_input_vec.size(); ++ui) { auto get_prev_rectangle = GetRectangle(prev_input_vec[ui]); if (!get_prev_rectangle.ok()) { return get_prev_rectangle.status(); } const Rectangle_f& prev_rect = get_prev_rectangle.value(); if (OverlapSimilarity(cur_rect, prev_rect) > options_.min_similarity_threshold()) { std::pair prev_id = GetId(prev_input_vec[ui]); // If prev_id.first is false when some element doesn't have an ID, // change_id and id_for_vi will not be updated. if (prev_id.first) { change_id = prev_id.first; id_for_vi = prev_id.second; } } } if (change_id) { T element = *vit; SetId(&element, id_for_vi); *vit = element; } } return absl::OkStatus(); } }; } // namespace mediapipe #endif // MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_