Implement face stylizer graph and its C++ API.

PiperOrigin-RevId: 515139282
This commit is contained in:
Jiuqiang Tang 2023-03-08 14:16:08 -08:00 committed by Copybara-Service
parent 0a60c67667
commit 253a5b477e
10 changed files with 995 additions and 94 deletions

View File

@ -0,0 +1,73 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
# Library target that builds face_stylizer_graph.cc together with the
# calculators, protos and framework libraries it depends on.
# NOTE(review): this target does not set `alwayslink`, unlike the
# strip_rotation_calculator target it depends on. If face_stylizer_graph.cc
# relies on static registration, confirm it is not dropped at link time.
cc_library(
name = "face_stylizer_graph",
srcs = ["face_stylizer_graph.cc"],
deps = [
"//mediapipe/calculators/image:image_cropping_calculator",
"//mediapipe/calculators/image:image_cropping_calculator_cc_proto",
"//mediapipe/calculators/image:warp_affine_calculator",
"//mediapipe/calculators/image:warp_affine_calculator_cc_proto",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/util:from_image_calculator",
"//mediapipe/calculators/util:inverse_matrix_calculator",
"//mediapipe/calculators/util:to_image_calculator",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:status",
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor",
],
)
# Library target for the FaceStylizer C++ task API (face_stylizer.h/.cc).
# It depends on :face_stylizer_graph, which is marked `buildcleaner:keep` so
# that automated dependency cleanup does not remove it.
cc_library(
name = "face_stylizer",
srcs = ["face_stylizer.cc"],
hdrs = ["face_stylizer.h"],
deps = [
":face_stylizer_graph", # buildcleaner:keep
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)

View File

@ -50,6 +50,7 @@ cc_library(
":tensors_to_image_calculator_cc_proto", ":tensors_to_image_calculator_cc_proto",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"//mediapipe/calculators/tensor:image_to_tensor_utils",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
@ -57,8 +58,11 @@ cc_library(
"//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
@ -106,3 +110,16 @@ cc_library(
], ],
}), }),
) )
# Library target for StripRotationCalculator (strip_rotation_calculator.cc).
# `alwayslink = 1` is set; presumably this keeps the calculator's static
# registration in the final binary even though nothing references its symbols
# directly.
cc_library(
name = "strip_rotation_calculator",
srcs = ["strip_rotation_calculator.cc"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:rect_cc_proto",
],
alwayslink = 1,
)

View File

@ -0,0 +1,50 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/rect.pb.h"
namespace mediapipe {
namespace tasks {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
} // namespace
// Forwards each incoming NormalizedRect with its `rotation` field cleared.
//
// Input streams:
//   NORM_RECT - the NormalizedRect to process.
//
// Output streams:
//   NORM_RECT - a copy of the input rect without rotation. Nothing is sent for
//               timestamps at which the input packet is empty.
class StripRotationCalculator : public Node {
 public:
  static constexpr Input<NormalizedRect> kInNormRect{"NORM_RECT"};
  static constexpr Output<NormalizedRect> kOutNormRect{"NORM_RECT"};
  MEDIAPIPE_NODE_CONTRACT(kInNormRect, kOutNormRect);

  absl::Status Process(CalculatorContext* cc) {
    // Nothing to forward when no rect arrived at this timestamp.
    if (kInNormRect(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    // Work on a copy: the input packet's payload is left untouched.
    NormalizedRect unrotated_rect = kInNormRect(cc).Get();
    unrotated_rect.clear_rotation();
    kOutNormRect(cc).Send(unrotated_rect);
    return absl::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::StripRotationCalculator);
} // namespace tasks
} // namespace mediapipe

View File

@ -18,14 +18,18 @@
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/calculators/tensor/image_to_tensor_utils.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/gpu_origin.pb.h" #include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.pb.h" #include "mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.pb.h"
@ -74,6 +78,15 @@ static int NumGroups(const int size, const int group_size) { // NOLINT
return (size + group_size - 1) / group_size; return (size + group_size - 1) / group_size;
} }
// Reports whether this build may take the GPU code path. The answer is fixed
// at compile time: false only when GPU support is disabled and Metal is not
// enabled, true otherwise.
bool CanUseGpu() {
#if MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
  return false;
#else
  constexpr bool kGpuPathAllowed = true;
  return kGpuPathAllowed;
#endif  // MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
}
} // namespace } // namespace
// Converts a MediaPipe tensor to a MediaPipe Image. // Converts a MediaPipe tensor to a MediaPipe Image.
@ -83,8 +96,6 @@ static int NumGroups(const int size, const int group_size) { // NOLINT
// //
// Output streams: // Output streams:
// OUTPUT - mediapipe::Image. // OUTPUT - mediapipe::Image.
//
// TODO: Enable TensorsToImageCalculator to run on CPU.
class TensorsToImageCalculator : public Node { class TensorsToImageCalculator : public Node {
public: public:
static constexpr Input<std::vector<Tensor>> kInputTensors{"TENSORS"}; static constexpr Input<std::vector<Tensor>> kInputTensors{"TENSORS"};
@ -98,6 +109,9 @@ class TensorsToImageCalculator : public Node {
absl::Status Close(CalculatorContext* cc); absl::Status Close(CalculatorContext* cc);
private: private:
TensorsToImageCalculatorOptions options_;
absl::Status CpuProcess(CalculatorContext* cc);
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
bool metal_initialized_ = false; bool metal_initialized_ = false;
@ -108,7 +122,7 @@ class TensorsToImageCalculator : public Node {
absl::Status MetalProcess(CalculatorContext* cc); absl::Status MetalProcess(CalculatorContext* cc);
#else #else
absl::Status GlSetup(CalculatorContext* cc); absl::Status GlSetup(CalculatorContext* cc);
absl::Status GlProcess(CalculatorContext* cc);
GlCalculatorHelper gl_helper_; GlCalculatorHelper gl_helper_;
bool gl_initialized_ = false; bool gl_initialized_ = false;
@ -136,112 +150,37 @@ absl::Status TensorsToImageCalculator::UpdateContract(CalculatorContract* cc) {
} }
absl::Status TensorsToImageCalculator::Open(CalculatorContext* cc) { absl::Status TensorsToImageCalculator::Open(CalculatorContext* cc) {
options_ = cc->Options<TensorsToImageCalculatorOptions>();
if (CanUseGpu()) {
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_); RET_CHECK(gpu_helper_);
#else #else
MP_RETURN_IF_ERROR(gl_helper_.Open(cc)); MP_RETURN_IF_ERROR(gl_helper_.Open(cc));
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
} else {
CHECK(options_.has_input_tensor_float_range() ^
options_.has_input_tensor_uint_range())
<< "Must specify either `input_tensor_float_range` or "
"`input_tensor_uint_range` in the calculator options";
}
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status TensorsToImageCalculator::Process(CalculatorContext* cc) { absl::Status TensorsToImageCalculator::Process(CalculatorContext* cc) {
if (CanUseGpu()) {
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
return MetalProcess(cc);
return MetalProcess(cc);
#else #else
return GlProcess(cc);
return gl_helper_.RunInGlContext([this, cc]() -> absl::Status {
if (!gl_initialized_) {
MP_RETURN_IF_ERROR(GlSetup(cc));
gl_initialized_ = true;
}
if (kInputTensors(cc).IsEmpty()) {
return absl::OkStatus();
}
const auto& input_tensors = kInputTensors(cc).Get();
RET_CHECK_EQ(input_tensors.size(), 1)
<< "Expect 1 input tensor, but have " << input_tensors.size();
const int tensor_width = input_tensors[0].shape().dims[2];
const int tensor_height = input_tensors[0].shape().dims[1];
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
auto out_texture = std::make_unique<tflite::gpu::gl::GlTexture>();
MP_RETURN_IF_ERROR(CreateReadWriteRgbaImageTexture(
tflite::gpu::DataType::UINT8, // GL_RGBA8
{tensor_width, tensor_height}, out_texture.get()));
const int output_index = 0;
glBindImageTexture(output_index, out_texture->id(), 0, GL_FALSE, 0,
GL_WRITE_ONLY, GL_RGBA8);
auto read_view = input_tensors[0].GetOpenGlBufferReadView();
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, read_view.name());
const tflite::gpu::uint3 workload = {tensor_width, tensor_height, 1};
const tflite::gpu::uint3 workgroups =
tflite::gpu::DivideRoundUp(workload, workgroup_size_);
glUseProgram(gl_compute_program_->id());
glUniform2i(glGetUniformLocation(gl_compute_program_->id(), "out_size"),
tensor_width, tensor_height);
MP_RETURN_IF_ERROR(gl_compute_program_->Dispatch(workgroups));
auto texture_buffer = mediapipe::GlTextureBuffer::Wrap(
out_texture->target(), out_texture->id(), tensor_width, tensor_height,
mediapipe::GpuBufferFormat::kBGRA32,
[ptr = out_texture.release()](
std::shared_ptr<mediapipe::GlSyncPoint> sync_token) mutable {
delete ptr;
});
auto output =
std::make_unique<mediapipe::GpuBuffer>(std::move(texture_buffer));
kOutputImage(cc).Send(Image(*output));
;
#else
if (!input_tensors[0].ready_as_opengl_texture_2d()) {
(void)input_tensors[0].GetCpuReadView();
}
auto output_texture =
gl_helper_.CreateDestinationTexture(tensor_width, tensor_height);
gl_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
glActiveTexture(GL_TEXTURE1);
glBindTexture(GL_TEXTURE_2D,
input_tensors[0].GetOpenGlTexture2dReadView().name());
MP_RETURN_IF_ERROR(gl_renderer_->GlRender(
tensor_width, tensor_height, output_texture.width(),
output_texture.height(), mediapipe::FrameScaleMode::kStretch,
mediapipe::FrameRotation::kNone,
/*flip_horizontal=*/false, /*flip_vertical=*/false,
/*flip_texture=*/false));
glActiveTexture(GL_TEXTURE1);
glBindTexture(GL_TEXTURE_2D, 0);
auto output = output_texture.GetFrame<GpuBuffer>();
kOutputImage(cc).Send(Image(*output));
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
return mediapipe::OkStatus();
});
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
return absl::OkStatus(); }
return CpuProcess(cc);
} }
absl::Status TensorsToImageCalculator::Close(CalculatorContext* cc) { absl::Status TensorsToImageCalculator::Close(CalculatorContext* cc) {
@ -258,6 +197,61 @@ absl::Status TensorsToImageCalculator::Close(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
// Converts the single input tensor to an SRGB ImageFrame on the CPU and sends
// it on the output image stream.
//
// The tensor must have four dimensions; dims[1], dims[2] and dims[3] are read
// as height, width and channel count, and the channel count must be 3.
// kFloat32 tensors are rescaled from `input_tensor_float_range` and kUInt8
// tensors from `input_tensor_uint_range` to the output range [0, 255].
//
// Returns OK without emitting a packet when the input packet is empty.
// Returns an error status when the tensor count, rank, channel count or
// element type is unsupported, when the range option matching the element
// type is not set, or when the range transformation cannot be computed.
absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) {
  if (kInputTensors(cc).IsEmpty()) {
    return absl::OkStatus();
  }
  const auto& input_tensors = kInputTensors(cc).Get();
  RET_CHECK_EQ(input_tensors.size(), 1)
      << "Expect 1 input tensor, but have " << input_tensors.size();
  const auto& input_tensor = input_tensors[0];

  // dims[1..3] are indexed below; reject any other rank instead of reading
  // past the end of the dims container.
  const auto& dims = input_tensor.shape().dims;
  RET_CHECK_EQ(dims.size(), 4)
      << "Expect a 4-dimensional input tensor, but have " << dims.size()
      << " dimensions";
  const int tensor_in_height = dims[1];
  const int tensor_in_width = dims[2];
  const int tensor_in_channels = dims[3];
  RET_CHECK_EQ(tensor_in_channels, 3);

  auto output_frame = std::make_shared<ImageFrame>(
      mediapipe::ImageFormat::SRGB, tensor_in_width, tensor_in_height);
  // Non-owning view: writing through it fills `output_frame` in place.
  cv::Mat output_matview = mediapipe::formats::MatView(output_frame.get());

  constexpr float kOutputImageRangeMin = 0.0f;
  constexpr float kOutputImageRangeMax = 255.0f;
  if (input_tensor.element_type() == Tensor::ElementType::kFloat32) {
    RET_CHECK(options_.has_input_tensor_float_range())
        << "`input_tensor_float_range` must be set for kFloat32 tensors";
    // Keep the read view in a named local so that it outlives the cv::Mat
    // aliasing its buffer, instead of being destroyed at the end of the
    // constructor expression.
    auto tensor_view = input_tensor.GetCpuReadView();
    // cv::Mat takes a non-const pointer; the data is only read from here.
    cv::Mat tensor_matview(cv::Size(tensor_in_width, tensor_in_height),
                           CV_MAKETYPE(CV_32F, tensor_in_channels),
                           const_cast<float*>(tensor_view.buffer<float>()));
    const auto& input_range = options_.input_tensor_float_range();
    ASSIGN_OR_RETURN(auto transform,
                     GetValueRangeTransformation(
                         input_range.min(), input_range.max(),
                         kOutputImageRangeMin, kOutputImageRangeMax));
    tensor_matview.convertTo(output_matview, CV_8UC3, transform.scale,
                             transform.offset);
  } else if (input_tensor.element_type() == Tensor::ElementType::kUInt8) {
    RET_CHECK(options_.has_input_tensor_uint_range())
        << "`input_tensor_uint_range` must be set for kUInt8 tensors";
    // Same lifetime requirement as in the float branch above.
    auto tensor_view = input_tensor.GetCpuReadView();
    cv::Mat tensor_matview(cv::Size(tensor_in_width, tensor_in_height),
                           CV_MAKETYPE(CV_8U, tensor_in_channels),
                           const_cast<uint8_t*>(tensor_view.buffer<uint8_t>()));
    const auto& input_range = options_.input_tensor_uint_range();
    ASSIGN_OR_RETURN(auto transform,
                     GetValueRangeTransformation(
                         input_range.min(), input_range.max(),
                         kOutputImageRangeMin, kOutputImageRangeMax));
    tensor_matview.convertTo(output_matview, CV_8UC3, transform.scale,
                             transform.offset);
  } else {
    return absl::InvalidArgumentError(
        absl::Substitute("Type of tensor must be kFloat32 or kUInt8, got: $0",
                         input_tensor.element_type()));
  }
  kOutputImage(cc).Send(Image(output_frame));
  return absl::OkStatus();
}
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) { absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) {
@ -433,6 +427,90 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) {
return mediapipe::OkStatus(); return mediapipe::OkStatus();
} }
// Converts the single input tensor to a GPU-backed Image inside the GL context
// and sends it on the output image stream.
//
// GL resources are created lazily through GlSetup() on the first call. Returns
// OK without emitting a packet when the input packet is empty, and an error
// status when setup, the tensor-count check, or the GL conversion fails.
absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) {
return gl_helper_.RunInGlContext([this, cc]() -> absl::Status {
// One-time setup, performed here so that it runs inside the GL context.
if (!gl_initialized_) {
MP_RETURN_IF_ERROR(GlSetup(cc));
gl_initialized_ = true;
}
if (kInputTensors(cc).IsEmpty()) {
return absl::OkStatus();
}
const auto& input_tensors = kInputTensors(cc).Get();
RET_CHECK_EQ(input_tensors.size(), 1)
<< "Expect 1 input tensor, but have " << input_tensors.size();
// dims[2] is used as the width and dims[1] as the height.
// NOTE(review): dims is indexed without a rank check — confirm that callers
// always provide a tensor with at least three dimensions.
const int tensor_width = input_tensors[0].shape().dims[2];
const int tensor_height = input_tensors[0].shape().dims[1];
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
// Compute-shader path: the tensor's GL buffer is read by the compute program
// and the result is written into a freshly created RGBA texture.
auto out_texture = std::make_unique<tflite::gpu::gl::GlTexture>();
MP_RETURN_IF_ERROR(CreateReadWriteRgbaImageTexture(
tflite::gpu::DataType::UINT8, // GL_RGBA8
{tensor_width, tensor_height}, out_texture.get()));
// Bind the output texture as write-only image unit 0.
const int output_index = 0;
glBindImageTexture(output_index, out_texture->id(), 0, GL_FALSE, 0,
GL_WRITE_ONLY, GL_RGBA8);
// Bind the tensor's GL buffer to shader storage binding point 2.
auto read_view = input_tensors[0].GetOpenGlBufferReadView();
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, read_view.name());
// One invocation per output pixel, rounded up to whole workgroups.
const tflite::gpu::uint3 workload = {tensor_width, tensor_height, 1};
const tflite::gpu::uint3 workgroups =
tflite::gpu::DivideRoundUp(workload, workgroup_size_);
glUseProgram(gl_compute_program_->id());
glUniform2i(glGetUniformLocation(gl_compute_program_->id(), "out_size"),
tensor_width, tensor_height);
MP_RETURN_IF_ERROR(gl_compute_program_->Dispatch(workgroups));
// Ownership of the GlTexture object is released into the callback passed to
// Wrap(), which deletes it when invoked.
auto texture_buffer = mediapipe::GlTextureBuffer::Wrap(
out_texture->target(), out_texture->id(), tensor_width, tensor_height,
mediapipe::GpuBufferFormat::kBGRA32,
[ptr = out_texture.release()](
std::shared_ptr<mediapipe::GlSyncPoint> sync_token) mutable {
delete ptr;
});
auto output =
std::make_unique<mediapipe::GpuBuffer>(std::move(texture_buffer));
kOutputImage(cc).Send(Image(*output));
#else
// Render path: the tensor is read as a 2D texture and drawn into a
// destination texture by gl_renderer_.
// A CPU read view is requested (and discarded) when the tensor is not yet
// available as a 2D texture; presumably this makes the data available for
// the texture view requested below — confirm against the Tensor API.
if (!input_tensors[0].ready_as_opengl_texture_2d()) {
(void)input_tensors[0].GetCpuReadView();
}
auto output_texture =
gl_helper_.CreateDestinationTexture(tensor_width, tensor_height);
gl_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0
// The input tensor texture is bound to texture unit 1.
glActiveTexture(GL_TEXTURE1);
glBindTexture(GL_TEXTURE_2D,
input_tensors[0].GetOpenGlTexture2dReadView().name());
// Stretch to the destination size with no rotation and no flipping.
MP_RETURN_IF_ERROR(gl_renderer_->GlRender(
tensor_width, tensor_height, output_texture.width(),
output_texture.height(), mediapipe::FrameScaleMode::kStretch,
mediapipe::FrameRotation::kNone,
/*flip_horizontal=*/false, /*flip_vertical=*/false,
/*flip_texture=*/false));
// Unbind the input texture from unit 1 again.
glActiveTexture(GL_TEXTURE1);
glBindTexture(GL_TEXTURE_2D, 0);
auto output = output_texture.GetFrame<GpuBuffer>();
kOutputImage(cc).Send(Image(*output));
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
return mediapipe::OkStatus();
});
}
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED #endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
} // namespace tasks } // namespace tasks

View File

@ -28,4 +28,24 @@ message TensorsToImageCalculatorOptions {
// to be flipped vertically as tensors are expected to start at top. // to be flipped vertically as tensors are expected to start at top.
// (DEFAULT or unset interpreted as CONVENTIONAL.) // (DEFAULT or unset interpreted as CONVENTIONAL.)
optional mediapipe.GpuOrigin.Mode gpu_origin = 1; optional mediapipe.GpuOrigin.Mode gpu_origin = 1;
// Range of float values [min, max].
// `min` must be strictly less than `max`.
message FloatRange {
optional float min = 1;
optional float max = 2;
}
// Range of unsigned integer values [min, max].
// `min` must be strictly less than `max`.
message UIntRange {
optional uint64 min = 1;
optional uint64 max = 2;
}
// The value range of the input tensor elements that image pixels are
// converted from. At most one member can be set. The CPU path of
// TensorsToImageCalculator reads `input_tensor_float_range` for float32
// tensors and `input_tensor_uint_range` for uint8 tensors.
oneof range {
FloatRange input_tensor_float_range = 2;
UIntRange input_tensor_uint_range = 3;
}
} }

View File

@ -0,0 +1,196 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h"
#include <stdint.h>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options.pb.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace face_stylizer {
namespace {
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph";
constexpr char kStylizedImageTag[] = "STYLIZED_IMAGE";
constexpr char kStylizedImageName[] = "stylized_image";
constexpr int kMicroSecondsPerMilliSecond = 1000;
using FaceStylizerGraphOptionsProto =
::mediapipe::tasks::vision::face_stylizer::proto::FaceStylizerGraphOptions;
// Builds the top-level graph config for the task. The graph holds a single
// "mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph" subgraph node that
// receives `options`. The IMAGE and NORM_RECT graph inputs feed the subgraph,
// either directly or through the helper that adds a flow limiter when
// `enable_flow_limiting` is true. The subgraph's IMAGE and STYLIZED_IMAGE
// outputs become the graph outputs.
CalculatorGraphConfig CreateGraphConfig(
    std::unique_ptr<FaceStylizerGraphOptionsProto> options,
    bool enable_flow_limiting) {
  api2::builder::Graph graph;
  auto& stylizer_node = graph.AddNode(kSubgraphTypeName);
  // Hand the caller's options over to the node without copying them.
  stylizer_node.GetOptions<FaceStylizerGraphOptionsProto>().Swap(options.get());

  // Name the graph-level input streams.
  graph.In(kImageTag).SetName(kImageInStreamName);
  graph.In(kNormRectTag).SetName(kNormRectName);

  // Expose the subgraph outputs as graph outputs.
  stylizer_node.Out(kImageTag).SetName(kImageOutStreamName) >>
      graph.Out(kImageTag);
  stylizer_node.Out(kStylizedImageTag).SetName(kStylizedImageName) >>
      graph.Out(kStylizedImageTag);

  if (!enable_flow_limiting) {
    // Direct wiring: every input packet reaches the subgraph.
    graph.In(kImageTag) >> stylizer_node.In(kImageTag);
    graph.In(kNormRectTag) >> stylizer_node.In(kNormRectTag);
    return graph.GetConfig();
  }
  // The helper connects the inputs itself and returns the finished config.
  return tasks::core::AddFlowLimiterCalculator(
      graph, stylizer_node, {kImageTag, kNormRectTag}, kStylizedImageTag);
}
// Translates the user-facing FaceStylizerOptions struct into a newly allocated
// FaceStylizerGraphOptions proto. The base options are converted, and stream
// mode is enabled for every running mode other than IMAGE.
std::unique_ptr<FaceStylizerGraphOptionsProto>
ConvertFaceStylizerOptionsToProto(FaceStylizerOptions* options) {
  auto graph_options = std::make_unique<FaceStylizerGraphOptionsProto>();
  auto* base_options = graph_options->mutable_base_options();

  // Convert into a temporary, then swap it in to avoid a second copy.
  tasks::core::proto::BaseOptions converted_base_options(
      tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
  base_options->Swap(&converted_base_options);

  const bool use_stream_mode =
      options->running_mode != core::RunningMode::IMAGE;
  base_options->set_use_stream_mode(use_stream_mode);
  return graph_options;
}
} // namespace
// Creates a FaceStylizer from `options`.
//
// When a result callback is provided, it is wrapped in a packets callback
// that:
//   - on error, forwards the status with an empty image and the unset
//     timestamp value;
//   - does nothing when the output image packet is empty;
//   - otherwise forwards the stylized image, the output image and the
//     stylized image's timestamp converted to milliseconds.
// Flow limiting is enabled only for the LIVE_STREAM running mode.
absl::StatusOr<std::unique_ptr<FaceStylizer>> FaceStylizer::Create(
    std::unique_ptr<FaceStylizerOptions> options) {
  auto options_proto = ConvertFaceStylizerOptionsToProto(options.get());

  tasks::core::PacketsCallback packets_callback = nullptr;
  if (options->result_callback) {
    // The user callback is copied into the closure so that it stays valid
    // after `options` goes away.
    packets_callback =
        [user_callback = options->result_callback](
            absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
          if (!status_or_packets.ok()) {
            Image empty_image;
            user_callback(status_or_packets.status(), empty_image,
                          Timestamp::Unset().Value());
            return;
          }
          auto& packets = status_or_packets.value();
          if (packets[kImageOutStreamName].IsEmpty()) {
            return;
          }
          Packet stylized_packet = packets[kStylizedImageName];
          Packet original_packet = packets[kImageOutStreamName];
          const auto timestamp_ms = stylized_packet.Timestamp().Value() /
                                    kMicroSecondsPerMilliSecond;
          user_callback(stylized_packet.Get<Image>(),
                        original_packet.Get<Image>(), timestamp_ms);
        };
  }

  const bool enable_flow_limiting =
      options->running_mode == core::RunningMode::LIVE_STREAM;
  return core::VisionTaskApiFactory::Create<FaceStylizer,
                                            FaceStylizerGraphOptionsProto>(
      CreateGraphConfig(std::move(options_proto), enable_flow_limiting),
      std::move(options->base_options.op_resolver), options->running_mode,
      std::move(packets_callback));
}
// Stylizes a single image and returns the stylized result.
//
// Returns kInvalidArgument for GPU-backed input images, and propagates any
// error from converting `image_processing_options` or from running the graph.
absl::StatusOr<Image> FaceStylizer::Stylize(
    mediapipe::Image image,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options));

  // Build the input packets first, then hand them to the graph runner.
  Packet image_packet = MakePacket<Image>(std::move(image));
  Packet norm_rect_packet = MakePacket<NormalizedRect>(std::move(norm_rect));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData({{kImageInStreamName, std::move(image_packet)},
                        {kNormRectName, std::move(norm_rect_packet)}}));
  return output_packets[kStylizedImageName].Get<Image>();
}
// Stylizes one video frame and returns the stylized result. `timestamp_ms` is
// the frame timestamp in milliseconds; it is converted to microseconds and
// attached to both input packets.
//
// Returns kInvalidArgument for GPU-backed input images, and propagates any
// error from converting `image_processing_options` or from running the graph.
absl::StatusOr<Image> FaceStylizer::StylizeForVideo(
    mediapipe::Image image, int64_t timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options));

  // Both packets carry the same frame timestamp.
  const Timestamp frame_timestamp(timestamp_ms * kMicroSecondsPerMilliSecond);
  Packet image_packet =
      MakePacket<Image>(std::move(image)).At(frame_timestamp);
  Packet norm_rect_packet =
      MakePacket<NormalizedRect>(std::move(norm_rect)).At(frame_timestamp);
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData({{kImageInStreamName, std::move(image_packet)},
                        {kNormRectName, std::move(norm_rect_packet)}}));
  return output_packets[kStylizedImageName].Get<Image>();
}
// Sends one live-stream frame for stylization. `timestamp_ms` is the frame
// timestamp in milliseconds; it is converted to microseconds and attached to
// both input packets.
//
// Returns kInvalidArgument for GPU-backed input images, and propagates any
// error from converting `image_processing_options` or from sending the data.
absl::Status FaceStylizer::StylizeAsync(
    Image image, int64_t timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options));

  // Both packets carry the same frame timestamp.
  const Timestamp frame_timestamp(timestamp_ms * kMicroSecondsPerMilliSecond);
  Packet image_packet =
      MakePacket<Image>(std::move(image)).At(frame_timestamp);
  Packet norm_rect_packet =
      MakePacket<NormalizedRect>(std::move(norm_rect)).At(frame_timestamp);
  return SendLiveStreamData({{kImageInStreamName, std::move(image_packet)},
                             {kNormRectName, std::move(norm_rect_packet)}});
}
} // namespace face_stylizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,156 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_STYLIZER_FACE_STYLIZER_H_
#define MEDIAPIPE_TASKS_CC_VISION_FACE_STYLIZER_FACE_STYLIZER_H_
#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace face_stylizer {
// The options for configuring a MediaPipe face stylizer task.
struct FaceStylizerOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options;
// The running mode of the task. Defaults to the image mode.
// Face stylizer has three running modes:
// 1) The image mode for stylizing faces on single image inputs.
// 2) The video mode for stylizing faces on the decoded frames of a video.
// 3) The live stream mode for stylizing faces on the live stream of input
// data, such as from a camera. In this mode, the "result_callback" below must
// be specified to receive the stylization results asynchronously.
core::RunningMode running_mode = core::RunningMode::IMAGE;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.
//
// On success, the arguments are the stylized image, the output image the
// result belongs to, and the timestamp in milliseconds. On failure, they are
// the error status, an empty image, and the value of the unset timestamp.
std::function<void(absl::StatusOr<mediapipe::Image>, const Image&, int64_t)>
result_callback = nullptr;
};
// Performs face stylization on images.
class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
public:
using BaseVisionTaskApi::BaseVisionTaskApi;
// Creates a FaceStylizer from the provided options.
static absl::StatusOr<std::unique_ptr<FaceStylizer>> Create(
std::unique_ptr<FaceStylizerOptions> options);
// Performs face stylization on the provided single image.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing stylization, by
// setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform stylization, by setting its
// 'region_of_interest' field. If not specified, the full image is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the FaceStylizer is created with the image
// running mode.
//
// The input image can be of any size with format RGB or RGBA.
// To ensure that the output image has reasonable quality, the stylized output
// image size is the smaller of the model output size and the size of the
// 'region_of_interest' specified in 'image_processing_options'.
absl::StatusOr<mediapipe::Image> Stylize(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Performs face stylization on the provided video frame.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing stylization, by
// setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform stylization, by setting its
// 'region_of_interest' field. If not specified, the full image is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the FaceStylizer is created with the video
// running mode.
//
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
// To ensure that the output image has reasonable quailty, the stylized output
// image size is the smaller of the model output size and the size of the
// 'region_of_interest' specified in 'image_processing_options'.
absl::StatusOr<mediapipe::Image> StylizeForVideo(
mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to perform face stylization, and the results will
// be available via the "result_callback" provided in the
// FaceStylizerOptions.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing stylization, by
// setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform stylization, by setting its
// 'region_of_interest' field. If not specified, the full image is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the FaceStylizer is created with the live stream
// running mode.
//
// The image can be of any size with format RGB or RGBA. It's required to
// provide a timestamp (in milliseconds) to indicate when the input image is
// sent to the face stylizer. The input timestamps must be monotonically
// increasing.
//
// The "result_callback" provides:
// - The stylized image which size is the smaller of the model output size
// and the size of the 'region_of_interest' specified in
// 'image_processing_options'.
// - The input timestamp in milliseconds.
absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
// Shuts down the FaceStylizer when all works are done.
absl::Status Close() { return runner_->Close(); }
};
} // namespace face_stylizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_STYLIZER_FACE_STYLIZER_H_

View File

@ -0,0 +1,246 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/calculators/image/image_cropping_calculator.pb.h"
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace face_stylizer {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::TensorsToImageCalculatorOptions;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::vision::face_stylizer::proto::
FaceStylizerGraphOptions;
// Stream tags used below to connect the graph inputs/outputs and the
// calculator nodes of the face stylizer graph.
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageCpuTag[] = "IMAGE_CPU";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kMatrixTag[] = "MATRIX";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
constexpr char kStylizedImageTag[] = "STYLIZED_IMAGE";
constexpr char kTensorsTag[] = "TENSORS";
// Struct holding the different output streams produced by the face stylizer
// graph.
//
// The fields are filled positionally by aggregate initialization, so their
// declaration order must not change.
struct FaceStylizerOutputStreams {
// Stream carrying the stylized result; routed to the STYLIZED_IMAGE output.
Source<Image> stylized_image;
// Stream taken from the preprocessing node's IMAGE output; routed to the
// IMAGE output.
Source<Image> original_image;
};
// Fills in the TensorsToImageCalculator options so that they are consistent
// with the tensor range configured for the ImageToTensorCalculator.
//
// image_to_tensor_options: the options used to turn the input image into
//   the model's input tensor.
// tensors_to_image_options: output; receives the GPU origin and the matching
//   input tensor range. Left without a range when the reference options
//   specify neither a float nor a uint output range.
void ConfigureTensorsToImageCalculator(
    const ImageToTensorCalculatorOptions& image_to_tensor_options,
    TensorsToImageCalculatorOptions* tensors_to_image_options) {
  tensors_to_image_options->set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT);

  if (image_to_tensor_options.has_output_tensor_float_range()) {
    // TODO: Make the float range flexible. For now a fixed [0, 1] range is
    // used regardless of the range configured on the input side.
    auto& float_range =
        *tensors_to_image_options->mutable_input_tensor_float_range();
    float_range.set_min(0);
    float_range.set_max(1);
    return;
  }

  if (!image_to_tensor_options.has_output_tensor_uint_range()) {
    return;
  }

  // Mirror the unsigned integer range of the input tensor.
  auto& uint_range =
      *tensors_to_image_options->mutable_input_tensor_uint_range();
  const auto& source_range = image_to_tensor_options.output_tensor_uint_range();
  uint_range.set_min(source_range.min());
  uint_range.set_max(source_range.max());
}
} // namespace
// A "mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph" performs face
// stylization.
//
// Inputs:
// IMAGE - Image
// Image to perform face stylization on.
// NORM_RECT - NormalizedRect @Optional
// Describes region of image to perform stylization on.
// @Optional: rect covering the whole image is used if not specified.
//
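// Outputs:
// IMAGE - mediapipe::Image
// The original image, as emitted by the image preprocessing subgraph.
// STYLIZED_IMAGE - mediapipe::Image
// The face stylization output image.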
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph"
// input_stream: "IMAGE:image_in"
// input_stream: "NORM_RECT:norm_rect"
// output_stream: "IMAGE:image_out"
// output_stream: "STYLIZED_IMAGE:stylized_image"
// options {
// [mediapipe.tasks.vision.face_stylizer.proto.FaceStylizerGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "face_stylization.tflite"
// }
// }
// }
// }
// }
class FaceStylizerGraph : public core::ModelTaskGraph {
public:
// Builds the subgraph config for the face stylizer.
//
// Creates the model resources described by the FaceStylizerGraphOptions,
// feeds the IMAGE and optional NORM_RECT graph inputs into
// BuildFaceStylizerGraph(), and exposes the returned streams as the
// STYLIZED_IMAGE and IMAGE graph outputs.
//
// Returns the generated CalculatorGraphConfig, or the error produced while
// creating the model resources or building the graph.
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<FaceStylizerGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
auto output_streams,
BuildFaceStylizerGraph(
sc->Options<FaceStylizerGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.stylized_image >> graph[Output<Image>(kStylizedImageTag)];
output_streams.original_image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
private:
// Adds all nodes of the face stylizer pipeline to `graph` and wires them up:
// preprocessing -> inference -> tensors-to-image -> inverse warp ->
// rotation-free crop -> Image conversion.
//
// task_options: graph options; the acceleration settings of its base options
//   select the CPU or GPU path.
// model_resources: resources of the stylization model.
// image_in: stream of input images.
// norm_rect_in: stream of regions of interest within the input images.
// graph: the graph builder that receives the nodes.
//
// Returns the stylized image stream and the image stream emitted by the
// preprocessing node, or an error if the preprocessing graph cannot be
// configured.
absl::StatusOr<FaceStylizerOutputStreams> BuildFaceStylizerGraph(
const FaceStylizerGraphOptions& task_options,
const ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
// Adds preprocessing calculators and connects them to the graph input image
// stream.
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
// Adjusts the image-to-tensor options after the generic configuration above:
// the aspect ratio is kept and the zero border mode is selected.
auto& image_to_tensor_options =
*preprocessing
.GetOptions<components::processors::proto::
ImagePreprocessingGraphOptions>()
.mutable_image_to_tensor_options();
image_to_tensor_options.set_keep_aspect_ratio(true);
image_to_tensor_options.set_border_mode(
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);
auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
auto transform_matrix = preprocessing.Out(kMatrixTag);
auto image_size = preprocessing.Out(kImageSizeTag);
// Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator.
auto& inference = AddInference(
model_resources, task_options.base_options().acceleration(), graph);
preprocessed_tensors >> inference.In(kTensorsTag);
auto model_output_tensors =
inference.Out(kTensorsTag).Cast<std::vector<Tensor>>();
// Converts the model output tensors back into an image, using a tensor range
// derived from the image-to-tensor options.
auto& tensors_to_image =
graph.AddNode("mediapipe.tasks.TensorsToImageCalculator");
ConfigureTensorsToImageCalculator(
image_to_tensor_options,
&tensors_to_image.GetOptions<TensorsToImageCalculatorOptions>());
model_output_tensors >> tensors_to_image.In(kTensorsTag);
auto tensor_image = tensors_to_image.Out(kImageTag);
// Inverts the transform matrix reported by preprocessing and warps the
// tensor image with it. The output size comes from the preprocessing
// IMAGE_SIZE stream (presumably the input image dimensions; TODO confirm).
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
transform_matrix >> inverse_matrix.In(kMatrixTag);
auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag);
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
auto& warp_affine_options =
warp_affine.GetOptions<WarpAffineCalculatorOptions>();
warp_affine_options.set_border_mode(
WarpAffineCalculatorOptions::BORDER_ZERO);
warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT);
tensor_image >> warp_affine.In(kImageTag);
inverse_transform_matrix >> warp_affine.In(kMatrixTag);
image_size >> warp_affine.In(kOutputSizeTag);
auto image_to_crop = warp_affine.Out(kImageTag);
// The following calculators are for cropping and resizing the output image
// based on the roi and the model output size. As the WarpAffineCalculator
// rotates the image based on the transform matrix, the rotation info in the
// rect proto is stripped to prevent the ImageCroppingCalculator from
// performing extra rotation.
auto& strip_rotation =
graph.AddNode("mediapipe.tasks.StripRotationCalculator");
norm_rect_in >> strip_rotation.In(kNormRectTag);
auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag);
auto& from_image = graph.AddNode("FromImageCalculator");
image_to_crop >> from_image.In(kImageTag);
auto& image_cropping = graph.AddNode("ImageCroppingCalculator");
auto& image_cropping_opts =
image_cropping.GetOptions<ImageCroppingCalculatorOptions>();
// Caps the cropped output at the model's tensor width and height.
image_cropping_opts.set_output_max_width(
image_to_tensor_options.output_tensor_width());
image_cropping_opts.set_output_max_height(
image_to_tensor_options.output_tensor_height());
norm_rect_no_rotation >> image_cropping.In(kNormRectTag);
auto& to_image = graph.AddNode("ToImageCalculator");
// ImageCroppingCalculator currently doesn't support mediapipe::Image, the
// graph selects its cpu or gpu path based on the image preprocessing
// backend.
if (use_gpu) {
from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag);
image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag);
} else {
// On the CPU path the cropping calculator is wired through its IMAGE tag,
// while FromImage/ToImage use IMAGE_CPU.
from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag);
image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag);
}
return {{/*stylized_image=*/to_image.Out(kImageTag).Cast<Image>(),
/*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
}
};
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::face_stylizer::FaceStylizerGraph); // NOLINT
// clang-format on
} // namespace face_stylizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,31 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
# Proto library built from face_stylizer_graph_options.proto. The deps cover
# the three .proto files that the source file imports.
mediapipe_proto_library(
name = "face_stylizer_graph_options_proto",
srcs = ["face_stylizer_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -0,0 +1,34 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.vision.face_stylizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.facestylizer.proto";
option java_outer_classname = "FaceStylizerGraphOptionsProto";
// Options for the face stylizer graph.
message FaceStylizerGraphOptions {
// Registers this message as an extension of mediapipe.CalculatorOptions so
// that it can be set inside a node's `options` block via the `ext` field.
extend mediapipe.CalculatorOptions {
optional FaceStylizerGraphOptions ext = 513916220;
}
// Base options for configuring face stylizer, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
}