// Source revision 33d683c671 (GitOrigin-RevId:
// 373e3ac1e5839befd95bf7d73ceff3c5f1171969); 198 lines, 7.1 KiB, C++.
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
#include "mediapipe/calculators/util/landmarks_refinement_calculator.h"
|
|
|
|
#include <algorithm>
|
|
#include <set>
|
|
#include <utility>
|
|
|
|
#include "absl/memory/memory.h"
|
|
#include "mediapipe/calculators/util/landmarks_refinement_calculator.pb.h"
|
|
#include "mediapipe/framework/api2/node.h"
|
|
#include "mediapipe/framework/api2/port.h"
|
|
#include "mediapipe/framework/calculator_framework.h"
|
|
#include "mediapipe/framework/port/proto_ns.h"
|
|
#include "mediapipe/framework/port/ret_check.h"
|
|
|
|
namespace mediapipe {
|
|
|
|
namespace api2 {
|
|
|
|
namespace {
|
|
|
|
// Returns the size of the refined landmark list, i.e. the number of distinct
// indexes used across the `indexes_mapping` of all refinements.
//
// The union of all mapping indexes must be exactly {0, 1, ..., N - 1}: it must
// be non-empty, start at 0 and contain no gaps. Otherwise an error status is
// returned (via RET_CHECK). The same index may appear in several refinements.
absl::StatusOr<int> GetNumberOfRefinedLandmarks(
    const proto_ns::RepeatedPtrField<
        LandmarksRefinementCalculatorOptions::Refinement>& refinements) {
  // Gather all used indexes. Range-for avoids nested index counters (the
  // previous nested `int i` loops shadowed each other).
  std::set<int> idxs;
  for (const auto& refinement : refinements) {
    for (const int idx : refinement.indexes_mapping()) {
      idxs.insert(idx);
    }
  }

  // Check that indexes start with 0 and there are no gaps between the min and
  // max indexes. Since `idxs` holds distinct values, once min == 0 "no gaps"
  // is equivalent to max == size - 1.
  RET_CHECK(!idxs.empty())
      << "There should be at least one landmark in indexes mapping";
  const int idxs_min = *idxs.begin();
  const int idxs_max = *idxs.rbegin();
  const int n_idxs = static_cast<int>(idxs.size());
  RET_CHECK_EQ(idxs_min, 0)
      << "Indexes are expected to start with 0 instead of " << idxs_min;
  RET_CHECK_EQ(idxs_max, n_idxs - 1)
      << "Indexes should have no gaps but " << idxs_max - n_idxs + 1
      << " indexes are missing";

  return n_idxs;
}
|
|
|
|
// Writes the X and Y coordinates of every landmark in `landmarks` into
// `refined_landmarks`: landmark `src` lands in slot `indexes_mapping[src]`.
// Z, visibility and presence of the refined landmarks are not touched here.
// The caller must ensure `indexes_mapping` has at least
// `landmarks.landmark_size()` entries and that each entry is a valid slot in
// `refined_landmarks`.
void RefineXY(const proto_ns::RepeatedField<int>& indexes_mapping,
              const NormalizedLandmarkList& landmarks,
              NormalizedLandmarkList* refined_landmarks) {
  const int n_landmarks = landmarks.landmark_size();
  for (int src = 0; src < n_landmarks; ++src) {
    const int dst = indexes_mapping.Get(src);
    const auto& source = landmarks.landmark(src);
    auto* target = refined_landmarks->mutable_landmark(dst);
    target->set_x(source.x());
    target->set_y(source.y());
  }
}
|
|
|
|
float GetZAverage(const NormalizedLandmarkList& landmarks,
|
|
const proto_ns::RepeatedField<int>& indexes) {
|
|
double z_sum = 0;
|
|
for (int i = 0; i < indexes.size(); ++i) {
|
|
z_sum += landmarks.landmark(indexes.Get(i)).z();
|
|
}
|
|
return z_sum / indexes.size();
|
|
}
|
|
|
|
void RefineZ(
|
|
const proto_ns::RepeatedField<int>& indexes_mapping,
|
|
const LandmarksRefinementCalculatorOptions::ZRefinement& z_refinement,
|
|
const NormalizedLandmarkList& landmarks,
|
|
NormalizedLandmarkList* refined_landmarks) {
|
|
if (z_refinement.has_none()) {
|
|
// Do nothing and keep Z that is already in refined landmarks.
|
|
} else if (z_refinement.has_copy()) {
|
|
for (int i = 0; i < landmarks.landmark_size(); ++i) {
|
|
refined_landmarks->mutable_landmark(indexes_mapping.Get(i))
|
|
->set_z(landmarks.landmark(i).z());
|
|
}
|
|
} else if (z_refinement.has_assign_average()) {
|
|
const float z_average =
|
|
GetZAverage(*refined_landmarks,
|
|
z_refinement.assign_average().indexes_for_average());
|
|
for (int i = 0; i < indexes_mapping.size(); ++i) {
|
|
refined_landmarks->mutable_landmark(indexes_mapping.Get(i))
|
|
->set_z(z_average);
|
|
}
|
|
} else {
|
|
CHECK(false) << "Z refinement is either not specified or not supported";
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
class LandmarksRefinementCalculatorImpl
|
|
: public NodeImpl<LandmarksRefinementCalculator> {
|
|
absl::Status Open(CalculatorContext* cc) override {
|
|
options_ = cc->Options<LandmarksRefinementCalculatorOptions>();
|
|
|
|
// Validate refinements.
|
|
for (int i = 0; i < options_.refinement_size(); ++i) {
|
|
const auto& refinement = options_.refinement(i);
|
|
RET_CHECK_GT(refinement.indexes_mapping_size(), 0)
|
|
<< "Refinement " << i << " has no indexes mapping";
|
|
RET_CHECK(refinement.has_z_refinement())
|
|
<< "Refinement " << i << " has no Z refinement specified";
|
|
RET_CHECK(refinement.z_refinement().has_none() ^
|
|
refinement.z_refinement().has_copy() ^
|
|
refinement.z_refinement().has_assign_average())
|
|
<< "Exactly one Z refinement should be specified";
|
|
|
|
const auto z_refinement = refinement.z_refinement();
|
|
if (z_refinement.has_assign_average()) {
|
|
RET_CHECK_GT(z_refinement.assign_average().indexes_for_average_size(),
|
|
0)
|
|
<< "When using assign average Z refinement at least one index for "
|
|
"averagin should be specified";
|
|
}
|
|
}
|
|
|
|
// Validate indexes mapping and get total number of refined landmarks.
|
|
ASSIGN_OR_RETURN(n_refined_landmarks_,
|
|
GetNumberOfRefinedLandmarks(options_.refinement()));
|
|
|
|
// Validate that number of refinements and landmark streams is the same.
|
|
RET_CHECK_EQ(kLandmarks(cc).Count(), options_.refinement_size())
|
|
<< "There are " << options_.refinement_size() << " refinements while "
|
|
<< kLandmarks(cc).Count() << " landmark streams";
|
|
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
absl::Status Process(CalculatorContext* cc) override {
|
|
// If any of the refinement landmarks is missing - refinement won't happen.
|
|
for (const auto& landmarks_stream : kLandmarks(cc)) {
|
|
if (landmarks_stream.IsEmpty()) {
|
|
return absl::OkStatus();
|
|
}
|
|
}
|
|
|
|
// Initialize refined landmarks list.
|
|
auto refined_landmarks = absl::make_unique<NormalizedLandmarkList>();
|
|
for (int i = 0; i < n_refined_landmarks_; ++i) {
|
|
refined_landmarks->add_landmark();
|
|
}
|
|
|
|
// Apply input landmarks to outpu refined landmarks in provided order.
|
|
for (int i = 0; i < kLandmarks(cc).Count(); ++i) {
|
|
const auto& landmarks = kLandmarks(cc)[i].Get();
|
|
const auto& refinement = options_.refinement(i);
|
|
|
|
// Check number of landmarks in mapping and stream are the same.
|
|
RET_CHECK_EQ(landmarks.landmark_size(), refinement.indexes_mapping_size())
|
|
<< "There are " << landmarks.landmark_size()
|
|
<< " refinement landmarks while mapping has "
|
|
<< refinement.indexes_mapping_size();
|
|
|
|
// Refine X and Y.
|
|
RefineXY(refinement.indexes_mapping(), landmarks,
|
|
refined_landmarks.get());
|
|
|
|
// Refine Z.
|
|
RefineZ(refinement.indexes_mapping(), refinement.z_refinement(),
|
|
landmarks, refined_landmarks.get());
|
|
|
|
// Visibility and presence are not currently refined and are left as `0`.
|
|
}
|
|
|
|
kRefinedLandmarks(cc).Send(std::move(refined_landmarks));
|
|
return absl::OkStatus();
|
|
}
|
|
|
|
private:
|
|
LandmarksRefinementCalculatorOptions options_;
|
|
int n_refined_landmarks_ = 0;
|
|
};
|
|
|
|
// Hooks LandmarksRefinementCalculatorImpl up as the implementation of the
// LandmarksRefinementCalculator node it derives from (NodeImpl<...>); the
// registration details are handled by the MediaPipe macro.
MEDIAPIPE_NODE_IMPLEMENTATION(LandmarksRefinementCalculatorImpl);
|
|
|
|
} // namespace api2
|
|
} // namespace mediapipe
|