Add the "FACE_ALIGNMENT" output stream to the face stylizer graph.

PiperOrigin-RevId: 528345204
This commit is contained in:
Jiuqiang Tang 2023-04-30 16:57:32 -07:00 committed by Copybara-Service
parent c450283715
commit c29e43dda0
2 changed files with 87 additions and 13 deletions

View File

@ -50,6 +50,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
"//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor",
],

View File

@ -41,6 +41,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options.pb.h"
#include "mediapipe/util/graph_builder_utils.h"
namespace mediapipe {
namespace tasks {
@ -64,6 +65,7 @@ using ::mediapipe::tasks::vision::face_stylizer::proto::
FaceStylizerGraphOptions;
constexpr char kDetectionTag[] = "DETECTION";
constexpr char kFaceAlignmentTag[] = "FACE_ALIGNMENT";
constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite";
constexpr char kFaceLandmarksDetectorTFLiteName[] =
"face_landmarks_detector.tflite";
@ -83,7 +85,8 @@ constexpr char kTensorsTag[] = "TENSORS";
// Struct holding the different output streams produced by the face stylizer
// graph.
struct FaceStylizerOutputStreams {
// NOTE(review): this line and the next both declare `stylized_image`; this
// looks like the pre-change and post-change lines of the diff shown
// together -- confirm which one the file actually contains.
Source<Image> stylized_image;
// Stylized output stream. Optional: BuildFaceStylizerGraph returns
// std::nullopt here on its alignment-only path.
std::optional<Source<Image>> stylized_image;
// Face alignment output stream. Optional: set on the alignment-only path,
// and on the stylization path only when alignment output is requested.
std::optional<Source<Image>> face_alignment_image;
// Image stream that GetConfig connects to the graph's IMAGE output.
Source<Image> original_image;
};
@ -190,8 +193,13 @@ void ConfigureTensorsToImageCalculator(
// @Optional: rect covering the whole image is used if not specified.
//
// Outputs:
// IMAGE - mediapipe::Image
// STYLIZED_IMAGE - mediapipe::Image
// The face stylization output image.
// FACE_ALIGNMENT - mediapipe::Image
// The face alignment output image.
// IMAGE - mediapipe::Image
// The input image that the face landmarker runs on, with its pixel data
// stored on the target storage (CPU vs GPU).
//
// Example:
// node {
@ -215,6 +223,8 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
bool output_stylized = !HasInput(sc->OriginalNode(), kStylizedImageTag);
bool output_alignment = !HasInput(sc->OriginalNode(), kFaceAlignmentTag);
ASSIGN_OR_RETURN(
const auto* model_asset_bundle_resources,
CreateModelAssetBundleResources<FaceStylizerGraphOptions>(sc));
@ -235,15 +245,27 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
->mutable_face_landmarker_graph_options(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
// Stays null when STYLIZED_IMAGE is not requested. BuildFaceStylizerGraph
// compares this pointer against nullptr to pick the alignment-only path, so
// it must never be left uninitialized.
const ModelResources* face_stylizer_model_resources = nullptr;
if (output_stylized) {
  // The stylizer model is only loaded when stylized output is requested.
  ASSIGN_OR_RETURN(
      const auto* model_resources,
      CreateModelResources(sc, std::move(face_stylizer_external_file)));
  face_stylizer_model_resources = model_resources;
}
ASSIGN_OR_RETURN(
auto output_streams,
BuildFaceStylizerGraph(sc->Options<FaceStylizerGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
face_stylizer_model_resources, output_alignment,
graph[Input<Image>(kImageTag)],
face_landmark_lists, graph));
output_streams.stylized_image >> graph[Output<Image>(kStylizedImageTag)];
if (output_stylized) {
  // Only present when the stylizer model was loaded and run.
  output_streams.stylized_image.value() >>
      graph[Output<Image>(kStylizedImageTag)];
}
if (output_alignment) {
  // FACE_ALIGNMENT must be fed by the alignment stream. The stylized stream
  // is std::nullopt on the alignment-only path, so reading it here would
  // fail, and on the stylization path it would emit the wrong image.
  output_streams.face_alignment_image.value() >>
      graph[Output<Image>(kFaceAlignmentTag)];
}
output_streams.original_image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
@ -277,9 +299,11 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
absl::StatusOr<FaceStylizerOutputStreams> BuildFaceStylizerGraph(
const FaceStylizerGraphOptions& task_options,
const ModelResources& model_resources, Source<Image> image_in,
const ModelResources* model_resources, bool output_alignment,
Source<Image> image_in,
Source<std::vector<NormalizedLandmarkList>> face_landmark_lists,
Graph& graph) {
bool output_stylized = model_resources != nullptr;
auto& split_face_landmark_list =
graph.AddNode("SplitNormalizedLandmarkListVectorCalculator");
ConfigureSplitNormalizedLandmarkListVectorCalculator(
@ -303,15 +327,52 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
face_detection >> face_to_rect.In(kDetectionTag);
image_size >> face_to_rect.In(kImageSizeTag);
auto face_rect = face_to_rect.Out(kNormRectTag);
// Adds preprocessing calculators and connects them to the graph input image
// stream.
std::optional<Source<Image>> face_alignment;
// Output face alignment only.
// In this case, the face stylization model inference is not required.
// However, to keep consistent with the inference preprocessing steps, the
// ImageToTensorCalculator is still used to perform image rotation,
// cropping, and resizing.
if (!output_stylized) {
// Alignment-only path: no stylizer model resources were supplied.
// Forward the input image unchanged; this becomes `original_image`.
auto& pass_through = graph.AddNode("PassThroughCalculator");
image_in >> pass_through.In("");
// Convert the input image, restricted to `face_rect`, into tensors with a
// float range of [-1, 1], keeping the aspect ratio and using the
// BORDER_ZERO border mode.
auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator");
auto& image_to_tensor_options =
image_to_tensor.GetOptions<ImageToTensorCalculatorOptions>();
image_to_tensor_options.mutable_output_tensor_float_range()->set_min(-1);
image_to_tensor_options.mutable_output_tensor_float_range()->set_max(1);
image_to_tensor_options.set_keep_aspect_ratio(true);
image_to_tensor_options.set_border_mode(
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
image_in >> image_to_tensor.In(kImageTag);
face_rect >> image_to_tensor.In(kNormRectTag);
auto face_alignment_image = image_to_tensor.Out(kTensorsTag);
// Turn the tensors back into an image; the calculator is configured from
// the same image-to-tensor options used above.
auto& tensors_to_image =
graph.AddNode("mediapipe.tasks.TensorsToImageCalculator");
ConfigureTensorsToImageCalculator(
image_to_tensor_options,
&tensors_to_image.GetOptions<TensorsToImageCalculatorOptions>());
face_alignment_image >> tensors_to_image.In(kTensorsTag);
face_alignment = tensors_to_image.Out(kImageTag).Cast<Image>();
// No stylized stream exists on this path, hence std::nullopt.
return {{/*stylized_image=*/std::nullopt,
/*alignment_image=*/face_alignment,
/*original_image=*/pass_through.Out("").Cast<Image>()}};
}
std::optional<Source<Image>> stylized;
// Adds preprocessing calculators and connects them to the graph input
// image stream.
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
*model_resources, use_gpu,
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
auto& image_to_tensor_options =
@ -329,7 +390,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
// Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator.
auto& inference = AddInference(
model_resources, task_options.base_options().acceleration(), graph);
*model_resources, task_options.base_options().acceleration(), graph);
preprocessed_tensors >> inference.In(kTensorsTag);
auto model_output_tensors =
inference.Out(kTensorsTag).Cast<std::vector<Tensor>>();
@ -346,8 +407,20 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
image_converter.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
.set_output_on_gpu(false);
tensor_image >> image_converter.In("");
stylized = image_converter.Out("").Cast<Image>();
return {{/*stylized_image=*/image_converter.Out("").Cast<Image>(),
if (output_alignment) {
// Stylization path with alignment output requested: rebuild an image from
// `preprocessed_tensors` (the same tensors that are fed to inference) and
// expose it as the face alignment stream.
auto& tensors_to_image =
graph.AddNode("mediapipe.tasks.TensorsToImageCalculator");
ConfigureTensorsToImageCalculator(
image_to_tensor_options,
&tensors_to_image.GetOptions<TensorsToImageCalculatorOptions>());
preprocessed_tensors >> tensors_to_image.In(kTensorsTag);
face_alignment = tensors_to_image.Out(kImageTag).Cast<Image>();
}
return {{/*stylized_image=*/stylized,
/*alignment_image=*/face_alignment,
/*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
}
};