mediapipe-rs/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc
2022-06-11 12:25:48 -07:00

103 lines
3.7 KiB
C++

// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/modules/objectron/calculators/frame_annotation_tracker.h"
#include "absl/container/flat_hash_set.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/modules/objectron/calculators/annotation_data.pb.h"
#include "mediapipe/modules/objectron/calculators/box_util.h"
#include "mediapipe/util/tracking/box_tracker.pb.h"
namespace mediapipe {
void FrameAnnotationTracker::AddDetectionResult(
const FrameAnnotation& frame_annotation) {
const int64 time_us =
static_cast<int64>(std::round(frame_annotation.timestamp()));
for (const auto& object_annotation : frame_annotation.annotations()) {
detected_objects_[time_us + object_annotation.object_id()] =
object_annotation;
}
}
// Projects every pending detection onto the latest tracking result.
//
// For each entry of `detected_objects_`, in map iteration order:
//   * If its object id is already in `*cancel_object_ids`, the entry is
//     scheduled for removal and skipped.
//   * Otherwise, the first box in `tracked_boxes` with the same id becomes the
//     reference box. If there is no such box, or it has no id or a negative
//     id, an error is logged and the entry is scheduled for removal.
//   * Every other tracked box whose IoU with the reference box exceeds
//     `iou_threshold_` has its id inserted into `*cancel_object_ids`.
//     Detections visited later in this same call with that id are therefore
//     dropped.
//   * The detection's 2D keypoints are mapped from their own bounding
//     rectangle into the reference box. They are appended to the result as a
//     new ObjectAnnotation that carries the reference box id.
// All scheduled entries are erased from `detected_objects_` after the loop.
//
// `cancel_object_ids` must be non-null (enforced with CHECK). It is both read
// and extended by this call.
//
// Returns a FrameAnnotation with one annotation per surviving detection. Only
// the object id and the 2D keypoint coordinates are populated.
FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult(
    const TimedBoxProtoList& tracked_boxes,
    absl::flat_hash_set<int>* cancel_object_ids) {
  CHECK(cancel_object_ids != nullptr);
  FrameAnnotation result;
  // Erasing from the map while iterating over it would invalidate the loop
  // iterator, so stale keys are collected here and removed at the end.
  std::vector<int64> stale_keys;

  for (const auto& entry : detected_objects_) {
    const ObjectAnnotation& detection = entry.second;
    const int object_id = detection.object_id();

    if (cancel_object_ids->contains(object_id)) {
      // Remember duplicated detections' keys.
      stale_keys.push_back(entry.first);
      continue;
    }

    // Locate the first tracked box that carries this object's id.
    const TimedBoxProto* match = nullptr;
    for (const auto& candidate : tracked_boxes.box()) {
      if (candidate.id() == object_id) {
        match = &candidate;
        break;
      }
    }
    if (match == nullptr || !match->has_id() || match->id() < 0) {
      LOG(ERROR) << "Can't find matching tracked box for object id: "
                 << object_id << ". Likely lost tracking of it.";
      stale_keys.push_back(entry.first);
      continue;
    }
    const TimedBoxProto& ref_box = *match;

    // Any *other* tracked box that overlaps the reference box too much is
    // treated as a duplicate and cancelled.
    for (const auto& other : tracked_boxes.box()) {
      const bool is_other = other.id() != object_id;
      if (is_other && ComputeBoxIoU(ref_box, other) > iou_threshold_) {
        cancel_object_ids->insert(other.id());
      }
    }

    // Collect the detection's 2D keypoints and the rectangle they span.
    std::vector<cv::Point2f> source_points;
    for (const auto& keypoint : detection.keypoints()) {
      source_points.emplace_back(keypoint.point_2d().x(),
                                 keypoint.point_2d().y());
    }
    TimedBoxProto source_box;
    ComputeBoundingRect(source_points, &source_box);

    ObjectAnnotation* tracked = result.add_annotations();
    tracked->set_object_id(ref_box.id());
    // Re-express every keypoint relative to the tracked (reference) box.
    for (const cv::Point2f& point : source_points) {
      const cv::Point2f mapped =
          MapPoint(source_box, ref_box, point, img_width_, img_height_);
      auto* out = tracked->add_keypoints()->mutable_point_2d();
      out->set_x(mapped.x);
      out->set_y(mapped.y);
    }
  }

  for (const int64 key : stale_keys) {
    detected_objects_.erase(key);
  }
  return result;
}
} // namespace mediapipe