mediapipe-rs/mediapipe/modules/objectron/calculators/frame_annotation_tracker.h
2022-06-11 12:25:48 -07:00

63 lines
2.3 KiB
C++

// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_MODULES_OBJECTRON_CALCULATORS_FRAME_ANNOTATION_TRACKER_H_
#define MEDIAPIPE_MODULES_OBJECTRON_CALCULATORS_FRAME_ANNOTATION_TRACKER_H_
#include <functional>
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_set.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/modules/objectron/calculators/annotation_data.pb.h"
#include "mediapipe/util/tracking/box_tracker.pb.h"
namespace mediapipe {
// Caches detection results and combines them with the boxes reported by an
// external tracker to produce per-frame annotations.
class FrameAnnotationTracker {
 public:
  // Stores the matching threshold and the image dimensions.
  //
  // iou_threshold: two bounding boxes whose IoU is over this value are
  //     considered to describe the same object.
  // img_width, img_height: image dimensions (units are not stated in this
  //     header; presumably pixels -- confirm against the implementation).
  FrameAnnotationTracker(float iou_threshold, float img_width,
                         float img_height)
      : iou_threshold_{iou_threshold},
        img_width_{img_width},
        img_height_{img_height} {}

  // Adds detection results from an external detector.
  void AddDetectionResult(const FrameAnnotation& frame_annotation);

  // Consolidates the tracking result from an external tracker, associates it
  // with the detection result by object id, and produces the corresponding
  // result as a FrameAnnotation.
  //
  // When there are duplicates, the ids that need to be cancelled are output
  // in cancel_object_ids.
  //
  // Note: the returned FrameAnnotation is missing its timestamp; the caller
  // needs to fill that field.
  FrameAnnotation ConsolidateTrackingResult(
      const TimedBoxProtoList& tracked_boxes,
      absl::flat_hash_set<int>* cancel_object_ids);

 private:
  // IoU value above which two boxes are treated as the same object.
  float iou_threshold_;

  // Image dimensions supplied at construction.
  float img_width_;
  float img_height_;

  // Cached detection results over time.
  // Key is timestamp_us + object_id. The std::greater comparator keeps the
  // entries ordered from the largest key to the smallest.
  absl::btree_map<int64, ObjectAnnotation, std::greater<int64>>
      detected_objects_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_MODULES_OBJECTRON_CALCULATORS_FRAME_ANNOTATION_TRACKER_H_