Only apply face landmarks smoothing for stream mode (VIDEO and LIVE_STREAM).

PiperOrigin-RevId: 536455842
This commit is contained in:
MediaPipe Team 2023-05-30 11:14:09 -07:00 committed by Copybara-Service
parent fabde5f129
commit 056881f4a9
2 changed files with 97 additions and 1 deletions

View File

@ -454,7 +454,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
if (face_detector_options.num_faces() == 1) {
face_landmarks_detector_graph
.GetOptions<FaceLandmarksDetectorGraphOptions>()
.set_smooth_landmarks(true);
.set_smooth_landmarks(tasks_options.base_options().use_stream_mode());
} else if (face_detector_options.num_faces() > 1 &&
face_landmarks_detector_graph
.GetOptions<FaceLandmarksDetectorGraphOptions>()

View File

@ -14,11 +14,13 @@ limitations under the License.
==============================================================================*/
#include <optional>
#include <string>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
@ -31,6 +33,7 @@ limitations under the License.
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
@ -95,6 +98,8 @@ constexpr float kLandmarksDiffMargin = 0.03;
constexpr float kBlendshapesDiffMargin = 0.1;
constexpr float kFaceGeometryDiffMargin = 0.02;
constexpr char kLandmarksSmoothingCalculator[] = "LandmarksSmoothingCalculator";
template <typename ProtoT>
ProtoT GetExpectedProto(absl::string_view filename) {
ProtoT expected_proto;
@ -103,6 +108,13 @@ ProtoT GetExpectedProto(absl::string_view filename) {
return expected_proto;
}
struct VerifyExpandedConfigTestParams {
std::string test_name;
bool use_stream_mode;
int num_faces;
bool has_smoothing_calculator;
};
// Struct holding the parameters for parameterized FaceLandmarkerGraphTest
// class.
struct FaceLandmarkerGraphTestParams {
@ -165,6 +177,25 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateFaceLandmarkerGraphTaskRunner(
absl::make_unique<tasks::core::MediaPipeBuiltinOpResolver>());
}
absl::StatusOr<CalculatorGraphConfig> ExpandConfig(
const std::string& config_str) {
auto config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(config));
return graph.Config();
}
bool HasCalculatorInConfig(const std::string& calculator_name,
const CalculatorGraphConfig& config) {
for (const auto& node : config.node()) {
if (node.calculator() == calculator_name) {
return true;
}
}
return false;
}
// Helper function to construct NormalizeRect proto.
NormalizedRect MakeNormRect(float x_center, float y_center, float width,
float height, float rotation) {
@ -177,6 +208,71 @@ NormalizedRect MakeNormRect(float x_center, float y_center, float width,
return face_rect;
}
constexpr char kGraphConfigString[] = R"pb(
node {
calculator: "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph"
input_stream: "IMAGE:image_in"
output_stream: "NORM_LANDMARKS:face_landmarks"
options {
[mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarkerGraphOptions
.ext] {
base_options {
model_asset {
file_name: "mediapipe/tasks/testdata/vision/face_landmarker_v2_with_blendshapes.task"
}
use_stream_mode: $0
}
face_detector_graph_options { num_faces: $1 }
}
}
}
input_stream: "IMAGE:image_in"
)pb";
class VerifyExpandedConfig
: public testing::TestWithParam<VerifyExpandedConfigTestParams> {};
TEST_P(VerifyExpandedConfig, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
auto actual_graph,
ExpandConfig(absl::Substitute(
kGraphConfigString, GetParam().use_stream_mode ? "true" : "false",
std::to_string(GetParam().num_faces))));
if (GetParam().has_smoothing_calculator) {
EXPECT_TRUE(
HasCalculatorInConfig(kLandmarksSmoothingCalculator, actual_graph));
} else {
EXPECT_FALSE(
HasCalculatorInConfig(kLandmarksSmoothingCalculator, actual_graph));
}
}
INSTANTIATE_TEST_SUITE_P(
VerifyExpandedConfig, VerifyExpandedConfig,
Values(VerifyExpandedConfigTestParams{
/*test_name=*/"NonStreamOneFaceHasNoSmoothing",
/*use_stream_mode=*/false,
/*num_faces=*/1,
/*has_smoothing_calculator=*/false},
VerifyExpandedConfigTestParams{
/*test_name=*/"NonStreamTwoFaceHasNoSmoothing",
/*use_stream_mode=*/false,
/*num_faces=*/2,
/*has_smoothing_calculator=*/false},
VerifyExpandedConfigTestParams{
/*test_name=*/"StreamOneFaceHasSmoothing",
/*use_stream_mode=*/true,
/*num_faces=*/1,
/*has_smoothing_calculator=*/true},
VerifyExpandedConfigTestParams{
/*test_name=*/"StreamTwoFaceHasNoSmoothing",
/*use_stream_mode=*/true,
/*num_faces=*/2,
/*has_smoothing_calculator=*/false}),
[](const TestParamInfo<VerifyExpandedConfig::ParamType>& info) {
return info.param.test_name;
});
class FaceLandmarkerGraphTest
: public testing::TestWithParam<FaceLandmarkerGraphTestParams> {};