Only apply face landmarks smoothing for stream mode (VIDEO and LIVE_STREAM).
PiperOrigin-RevId: 536455842
This commit is contained in:
parent
fabde5f129
commit
056881f4a9
|
@ -454,7 +454,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
|
|||
if (face_detector_options.num_faces() == 1) {
|
||||
face_landmarks_detector_graph
|
||||
.GetOptions<FaceLandmarksDetectorGraphOptions>()
|
||||
.set_smooth_landmarks(true);
|
||||
.set_smooth_landmarks(tasks_options.base_options().use_stream_mode());
|
||||
} else if (face_detector_options.num_faces() > 1 &&
|
||||
face_landmarks_detector_graph
|
||||
.GetOptions<FaceLandmarksDetectorGraphOptions>()
|
||||
|
|
|
@ -14,11 +14,13 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
|
@ -31,6 +33,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
|
@ -95,6 +98,8 @@ constexpr float kLandmarksDiffMargin = 0.03;
|
|||
constexpr float kBlendshapesDiffMargin = 0.1;
|
||||
constexpr float kFaceGeometryDiffMargin = 0.02;
|
||||
|
||||
constexpr char kLandmarksSmoothingCalculator[] = "LandmarksSmoothingCalculator";
|
||||
|
||||
template <typename ProtoT>
|
||||
ProtoT GetExpectedProto(absl::string_view filename) {
|
||||
ProtoT expected_proto;
|
||||
|
@ -103,6 +108,13 @@ ProtoT GetExpectedProto(absl::string_view filename) {
|
|||
return expected_proto;
|
||||
}
|
||||
|
||||
struct VerifyExpandedConfigTestParams {
|
||||
std::string test_name;
|
||||
bool use_stream_mode;
|
||||
int num_faces;
|
||||
bool has_smoothing_calculator;
|
||||
};
|
||||
|
||||
// Struct holding the parameters for parameterized FaceLandmarkerGraphTest
|
||||
// class.
|
||||
struct FaceLandmarkerGraphTestParams {
|
||||
|
@ -165,6 +177,25 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateFaceLandmarkerGraphTaskRunner(
|
|||
absl::make_unique<tasks::core::MediaPipeBuiltinOpResolver>());
|
||||
}
|
||||
|
||||
absl::StatusOr<CalculatorGraphConfig> ExpandConfig(
|
||||
const std::string& config_str) {
|
||||
auto config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
|
||||
CalculatorGraph graph;
|
||||
MP_RETURN_IF_ERROR(graph.Initialize(config));
|
||||
return graph.Config();
|
||||
}
|
||||
|
||||
bool HasCalculatorInConfig(const std::string& calculator_name,
|
||||
const CalculatorGraphConfig& config) {
|
||||
for (const auto& node : config.node()) {
|
||||
if (node.calculator() == calculator_name) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Helper function to construct NormalizeRect proto.
|
||||
NormalizedRect MakeNormRect(float x_center, float y_center, float width,
|
||||
float height, float rotation) {
|
||||
|
@ -177,6 +208,71 @@ NormalizedRect MakeNormRect(float x_center, float y_center, float width,
|
|||
return face_rect;
|
||||
}
|
||||
|
||||
constexpr char kGraphConfigString[] = R"pb(
|
||||
node {
|
||||
calculator: "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph"
|
||||
input_stream: "IMAGE:image_in"
|
||||
output_stream: "NORM_LANDMARKS:face_landmarks"
|
||||
options {
|
||||
[mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarkerGraphOptions
|
||||
.ext] {
|
||||
base_options {
|
||||
model_asset {
|
||||
file_name: "mediapipe/tasks/testdata/vision/face_landmarker_v2_with_blendshapes.task"
|
||||
}
|
||||
use_stream_mode: $0
|
||||
}
|
||||
face_detector_graph_options { num_faces: $1 }
|
||||
}
|
||||
}
|
||||
}
|
||||
input_stream: "IMAGE:image_in"
|
||||
)pb";
|
||||
|
||||
class VerifyExpandedConfig
|
||||
: public testing::TestWithParam<VerifyExpandedConfigTestParams> {};
|
||||
|
||||
TEST_P(VerifyExpandedConfig, Succeeds) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto actual_graph,
|
||||
ExpandConfig(absl::Substitute(
|
||||
kGraphConfigString, GetParam().use_stream_mode ? "true" : "false",
|
||||
std::to_string(GetParam().num_faces))));
|
||||
if (GetParam().has_smoothing_calculator) {
|
||||
EXPECT_TRUE(
|
||||
HasCalculatorInConfig(kLandmarksSmoothingCalculator, actual_graph));
|
||||
} else {
|
||||
EXPECT_FALSE(
|
||||
HasCalculatorInConfig(kLandmarksSmoothingCalculator, actual_graph));
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
VerifyExpandedConfig, VerifyExpandedConfig,
|
||||
Values(VerifyExpandedConfigTestParams{
|
||||
/*test_name=*/"NonStreamOneFaceHasNoSmoothing",
|
||||
/*use_stream_mode=*/false,
|
||||
/*num_faces=*/1,
|
||||
/*has_smoothing_calculator=*/false},
|
||||
VerifyExpandedConfigTestParams{
|
||||
/*test_name=*/"NonStreamTwoFaceHasNoSmoothing",
|
||||
/*use_stream_mode=*/false,
|
||||
/*num_faces=*/2,
|
||||
/*has_smoothing_calculator=*/false},
|
||||
VerifyExpandedConfigTestParams{
|
||||
/*test_name=*/"StreamOneFaceHasSmoothing",
|
||||
/*use_stream_mode=*/true,
|
||||
/*num_faces=*/1,
|
||||
/*has_smoothing_calculator=*/true},
|
||||
VerifyExpandedConfigTestParams{
|
||||
/*test_name=*/"StreamTwoFaceHasNoSmoothing",
|
||||
/*use_stream_mode=*/true,
|
||||
/*num_faces=*/2,
|
||||
/*has_smoothing_calculator=*/false}),
|
||||
[](const TestParamInfo<VerifyExpandedConfig::ParamType>& info) {
|
||||
return info.param.test_name;
|
||||
});
|
||||
|
||||
class FaceLandmarkerGraphTest
|
||||
: public testing::TestWithParam<FaceLandmarkerGraphTestParams> {};
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user