Only apply face landmarks smoothing for stream mode (VIDEO and LIVE_STREAM).
PiperOrigin-RevId: 536455842
This commit is contained in:
parent
fabde5f129
commit
056881f4a9
|
@ -454,7 +454,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
if (face_detector_options.num_faces() == 1) {
|
if (face_detector_options.num_faces() == 1) {
|
||||||
face_landmarks_detector_graph
|
face_landmarks_detector_graph
|
||||||
.GetOptions<FaceLandmarksDetectorGraphOptions>()
|
.GetOptions<FaceLandmarksDetectorGraphOptions>()
|
||||||
.set_smooth_landmarks(true);
|
.set_smooth_landmarks(tasks_options.base_options().use_stream_mode());
|
||||||
} else if (face_detector_options.num_faces() > 1 &&
|
} else if (face_detector_options.num_faces() > 1 &&
|
||||||
face_landmarks_detector_graph
|
face_landmarks_detector_graph
|
||||||
.GetOptions<FaceLandmarksDetectorGraphOptions>()
|
.GetOptions<FaceLandmarksDetectorGraphOptions>()
|
||||||
|
|
|
@ -14,11 +14,13 @@ limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "absl/flags/flag.h"
|
#include "absl/flags/flag.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
@ -31,6 +33,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/file_helpers.h"
|
#include "mediapipe/framework/port/file_helpers.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
|
@ -95,6 +98,8 @@ constexpr float kLandmarksDiffMargin = 0.03;
|
||||||
constexpr float kBlendshapesDiffMargin = 0.1;
|
constexpr float kBlendshapesDiffMargin = 0.1;
|
||||||
constexpr float kFaceGeometryDiffMargin = 0.02;
|
constexpr float kFaceGeometryDiffMargin = 0.02;
|
||||||
|
|
||||||
|
constexpr char kLandmarksSmoothingCalculator[] = "LandmarksSmoothingCalculator";
|
||||||
|
|
||||||
template <typename ProtoT>
|
template <typename ProtoT>
|
||||||
ProtoT GetExpectedProto(absl::string_view filename) {
|
ProtoT GetExpectedProto(absl::string_view filename) {
|
||||||
ProtoT expected_proto;
|
ProtoT expected_proto;
|
||||||
|
@ -103,6 +108,13 @@ ProtoT GetExpectedProto(absl::string_view filename) {
|
||||||
return expected_proto;
|
return expected_proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct VerifyExpandedConfigTestParams {
|
||||||
|
std::string test_name;
|
||||||
|
bool use_stream_mode;
|
||||||
|
int num_faces;
|
||||||
|
bool has_smoothing_calculator;
|
||||||
|
};
|
||||||
|
|
||||||
// Struct holding the parameters for parameterized FaceLandmarkerGraphTest
|
// Struct holding the parameters for parameterized FaceLandmarkerGraphTest
|
||||||
// class.
|
// class.
|
||||||
struct FaceLandmarkerGraphTestParams {
|
struct FaceLandmarkerGraphTestParams {
|
||||||
|
@ -165,6 +177,25 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateFaceLandmarkerGraphTaskRunner(
|
||||||
absl::make_unique<tasks::core::MediaPipeBuiltinOpResolver>());
|
absl::make_unique<tasks::core::MediaPipeBuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<CalculatorGraphConfig> ExpandConfig(
|
||||||
|
const std::string& config_str) {
|
||||||
|
auto config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_RETURN_IF_ERROR(graph.Initialize(config));
|
||||||
|
return graph.Config();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HasCalculatorInConfig(const std::string& calculator_name,
|
||||||
|
const CalculatorGraphConfig& config) {
|
||||||
|
for (const auto& node : config.node()) {
|
||||||
|
if (node.calculator() == calculator_name) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to construct NormalizeRect proto.
|
// Helper function to construct NormalizeRect proto.
|
||||||
NormalizedRect MakeNormRect(float x_center, float y_center, float width,
|
NormalizedRect MakeNormRect(float x_center, float y_center, float width,
|
||||||
float height, float rotation) {
|
float height, float rotation) {
|
||||||
|
@ -177,6 +208,71 @@ NormalizedRect MakeNormRect(float x_center, float y_center, float width,
|
||||||
return face_rect;
|
return face_rect;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constexpr char kGraphConfigString[] = R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph"
|
||||||
|
input_stream: "IMAGE:image_in"
|
||||||
|
output_stream: "NORM_LANDMARKS:face_landmarks"
|
||||||
|
options {
|
||||||
|
[mediapipe.tasks.vision.face_landmarker.proto.FaceLandmarkerGraphOptions
|
||||||
|
.ext] {
|
||||||
|
base_options {
|
||||||
|
model_asset {
|
||||||
|
file_name: "mediapipe/tasks/testdata/vision/face_landmarker_v2_with_blendshapes.task"
|
||||||
|
}
|
||||||
|
use_stream_mode: $0
|
||||||
|
}
|
||||||
|
face_detector_graph_options { num_faces: $1 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_stream: "IMAGE:image_in"
|
||||||
|
)pb";
|
||||||
|
|
||||||
|
class VerifyExpandedConfig
|
||||||
|
: public testing::TestWithParam<VerifyExpandedConfigTestParams> {};
|
||||||
|
|
||||||
|
TEST_P(VerifyExpandedConfig, Succeeds) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto actual_graph,
|
||||||
|
ExpandConfig(absl::Substitute(
|
||||||
|
kGraphConfigString, GetParam().use_stream_mode ? "true" : "false",
|
||||||
|
std::to_string(GetParam().num_faces))));
|
||||||
|
if (GetParam().has_smoothing_calculator) {
|
||||||
|
EXPECT_TRUE(
|
||||||
|
HasCalculatorInConfig(kLandmarksSmoothingCalculator, actual_graph));
|
||||||
|
} else {
|
||||||
|
EXPECT_FALSE(
|
||||||
|
HasCalculatorInConfig(kLandmarksSmoothingCalculator, actual_graph));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
VerifyExpandedConfig, VerifyExpandedConfig,
|
||||||
|
Values(VerifyExpandedConfigTestParams{
|
||||||
|
/*test_name=*/"NonStreamOneFaceHasNoSmoothing",
|
||||||
|
/*use_stream_mode=*/false,
|
||||||
|
/*num_faces=*/1,
|
||||||
|
/*has_smoothing_calculator=*/false},
|
||||||
|
VerifyExpandedConfigTestParams{
|
||||||
|
/*test_name=*/"NonStreamTwoFaceHasNoSmoothing",
|
||||||
|
/*use_stream_mode=*/false,
|
||||||
|
/*num_faces=*/2,
|
||||||
|
/*has_smoothing_calculator=*/false},
|
||||||
|
VerifyExpandedConfigTestParams{
|
||||||
|
/*test_name=*/"StreamOneFaceHasSmoothing",
|
||||||
|
/*use_stream_mode=*/true,
|
||||||
|
/*num_faces=*/1,
|
||||||
|
/*has_smoothing_calculator=*/true},
|
||||||
|
VerifyExpandedConfigTestParams{
|
||||||
|
/*test_name=*/"StreamTwoFaceHasNoSmoothing",
|
||||||
|
/*use_stream_mode=*/true,
|
||||||
|
/*num_faces=*/2,
|
||||||
|
/*has_smoothing_calculator=*/false}),
|
||||||
|
[](const TestParamInfo<VerifyExpandedConfig::ParamType>& info) {
|
||||||
|
return info.param.test_name;
|
||||||
|
});
|
||||||
|
|
||||||
class FaceLandmarkerGraphTest
|
class FaceLandmarkerGraphTest
|
||||||
: public testing::TestWithParam<FaceLandmarkerGraphTestParams> {};
|
: public testing::TestWithParam<FaceLandmarkerGraphTestParams> {};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user