Internal change

PiperOrigin-RevId: 524317439
This commit is contained in:
Jiuqiang Tang 2023-04-14 09:58:32 -07:00 committed by Copybara-Service
parent 89e6b824ae
commit e6bb9e29c0
10 changed files with 348 additions and 109 deletions

View File

@ -22,29 +22,42 @@ cc_library(
name = "face_stylizer_graph", name = "face_stylizer_graph",
srcs = ["face_stylizer_graph.cc"], srcs = ["face_stylizer_graph.cc"],
deps = [ deps = [
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/image:image_cropping_calculator", "//mediapipe/calculators/image:image_cropping_calculator",
"//mediapipe/calculators/image:image_cropping_calculator_cc_proto", "//mediapipe/calculators/image:image_cropping_calculator_cc_proto",
"//mediapipe/calculators/image:warp_affine_calculator", "//mediapipe/calculators/image:warp_affine_calculator",
"//mediapipe/calculators/image:warp_affine_calculator_cc_proto", "//mediapipe/calculators/image:warp_affine_calculator_cc_proto",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:face_to_rect_calculator",
"//mediapipe/calculators/util:from_image_calculator", "//mediapipe/calculators/util:from_image_calculator",
"//mediapipe/calculators/util:inverse_matrix_calculator", "//mediapipe/calculators/util:inverse_matrix_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto",
"//mediapipe/calculators/util:to_image_calculator", "//mediapipe/calculators/util:to_image_calculator",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator", "//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator", "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto", "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],
alwayslink = 1, alwayslink = 1,
@ -58,6 +71,7 @@ cc_library(
":face_stylizer_graph", # buildcleaner:keep ":face_stylizer_graph", # buildcleaner:keep
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:base_vision_task_api",

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
@ -113,7 +114,10 @@ absl::StatusOr<std::unique_ptr<FaceStylizer>> FaceStylizer::Create(
Packet stylized_image_packet = Packet stylized_image_packet =
status_or_packets.value()[kStylizedImageName]; status_or_packets.value()[kStylizedImageName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(stylized_image_packet.Get<Image>(), result_callback(
stylized_image_packet.IsEmpty()
? std::nullopt
: std::optional<Image>(stylized_image_packet.Get<Image>()),
image_packet.Get<Image>(), image_packet.Get<Image>(),
stylized_image_packet.Timestamp().Value() / stylized_image_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond); kMicroSecondsPerMilliSecond);
@ -128,7 +132,7 @@ absl::StatusOr<std::unique_ptr<FaceStylizer>> FaceStylizer::Create(
std::move(packets_callback)); std::move(packets_callback));
} }
absl::StatusOr<Image> FaceStylizer::Stylize( absl::StatusOr<std::optional<Image>> FaceStylizer::Stylize(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -144,10 +148,13 @@ absl::StatusOr<Image> FaceStylizer::Stylize(
ProcessImageData( ProcessImageData(
{{kImageInStreamName, MakePacket<Image>(std::move(image))}, {{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}})); {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kStylizedImageName].Get<Image>(); return output_packets[kStylizedImageName].IsEmpty()
? std::nullopt
: std::optional<Image>(
output_packets[kStylizedImageName].Get<Image>());
} }
absl::StatusOr<Image> FaceStylizer::StylizeForVideo( absl::StatusOr<std::optional<Image>> FaceStylizer::StylizeForVideo(
mediapipe::Image image, int64_t timestamp_ms, mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -167,7 +174,10 @@ absl::StatusOr<Image> FaceStylizer::StylizeForVideo(
{kNormRectName, {kNormRectName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kStylizedImageName].Get<Image>(); return output_packets[kStylizedImageName].IsEmpty()
? std::nullopt
: std::optional<Image>(
output_packets[kStylizedImageName].Get<Image>());
} }
absl::Status FaceStylizer::StylizeAsync( absl::Status FaceStylizer::StylizeAsync(

View File

@ -53,7 +53,8 @@ struct FaceStylizerOptions {
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void(absl::StatusOr<mediapipe::Image>, const Image&, int64_t)> std::function<void(absl::StatusOr<std::optional<mediapipe::Image>>,
const Image&, int64_t)>
result_callback = nullptr; result_callback = nullptr;
}; };
@ -81,10 +82,12 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// running mode. // running mode.
// //
// The input image can be of any size with format RGB or RGBA. // The input image can be of any size with format RGB or RGBA.
// To ensure that the output image has reasonable quality, the stylized output // When no face is detected on the input image, the method returns a
// image size is the smaller of the model output size and the size of the // std::nullopt. Otherwise, returns the stylized image of the most visible
// 'region_of_interest' specified in 'image_processing_options'. // face. To ensure that the output image has reasonable quality, the stylized
absl::StatusOr<mediapipe::Image> Stylize( // output image size is the smaller of the model output size and the size of
// the 'region_of_interest' specified in 'image_processing_options'.
absl::StatusOr<std::optional<mediapipe::Image>> Stylize(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -106,10 +109,12 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
// To ensure that the output image has reasonable quality, the stylized output // When no face is detected on the input image, the method returns a
// image size is the smaller of the model output size and the size of the // std::nullopt. Otherwise, returns the stylized image of the most visible
// 'region_of_interest' specified in 'image_processing_options'. // face. To ensure that the output image has reasonable quality, the stylized
absl::StatusOr<mediapipe::Image> StylizeForVideo( // output image size is the smaller of the model output size and the size of
// the 'region_of_interest' specified in 'image_processing_options'.
absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo(
mediapipe::Image image, int64_t timestamp_ms, mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -136,8 +141,11 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// increasing. // increasing.
// //
// The "result_callback" provides: // The "result_callback" provides:
// - The stylized image which size is the smaller of the model output size // - When no face is detected on the input image, the method returns a
// and the size of the 'region_of_interest' specified in // std::nullopt. Otherwise, returns the stylized image of the most visible
// face. To ensure that the output image has reasonable quality, the
// stylized output image size is the smaller of the model output size and
// the size of the 'region_of_interest' specified in
// 'image_processing_options'. // 'image_processing_options'.
// - The input timestamp in milliseconds. // - The input timestamp in milliseconds.
absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms, absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms,

View File

@ -16,20 +16,30 @@ limitations under the License.
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/memory/memory.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/image/image_cropping_calculator.pb.h" #include "mediapipe/calculators/image/image_cropping_calculator.pb.h"
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h" #include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/gpu/gpu_origin.pb.h" #include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.pb.h" #include "mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options.pb.h"
@ -45,17 +55,29 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::TensorsToImageCalculatorOptions; using ::mediapipe::tasks::TensorsToImageCalculatorOptions;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::core::proto::ExternalFile;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::face_landmarker::proto::
FaceLandmarkerGraphOptions;
using ::mediapipe::tasks::vision::face_stylizer::proto:: using ::mediapipe::tasks::vision::face_stylizer::proto::
FaceStylizerGraphOptions; FaceStylizerGraphOptions;
constexpr char kDetectionTag[] = "DETECTION";
constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite";
constexpr char kFaceLandmarksDetectorTFLiteName[] =
"face_landmarks_detector.tflite";
constexpr char kFaceStylizerTFLiteName[] = "face_stylizer.tflite";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kImageCpuTag[] = "IMAGE_CPU"; constexpr char kImageCpuTag[] = "IMAGE_CPU";
constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kMatrixTag[] = "MATRIX"; constexpr char kMatrixTag[] = "MATRIX";
constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
constexpr char kSizeTag[] = "SIZE";
constexpr char kStylizedImageTag[] = "STYLIZED_IMAGE"; constexpr char kStylizedImageTag[] = "STYLIZED_IMAGE";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
@ -66,6 +88,76 @@ struct FaceStylizerOutputStreams {
Source<Image> original_image; Source<Image> original_image;
}; };
// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
FaceStylizerGraphOptions* options,
ExternalFile* face_stylizer_external_file,
bool is_copy) {
auto* face_detector_graph_options =
options->mutable_face_landmarker_graph_options()
->mutable_face_detector_graph_options();
if (!face_detector_graph_options->base_options().has_model_asset()) {
ASSIGN_OR_RETURN(const auto face_detector_file,
resources.GetFile(kFaceDetectorTFLiteName));
SetExternalFile(face_detector_file,
face_detector_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
}
face_detector_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
face_detector_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
auto* face_landmarks_detector_graph_options =
options->mutable_face_landmarker_graph_options()
->mutable_face_landmarks_detector_graph_options();
if (!face_landmarks_detector_graph_options->base_options()
.has_model_asset()) {
ASSIGN_OR_RETURN(const auto face_landmarks_detector_file,
resources.GetFile(kFaceLandmarksDetectorTFLiteName));
SetExternalFile(
face_landmarks_detector_file,
face_landmarks_detector_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
}
face_landmarks_detector_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
face_landmarks_detector_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode());
ASSIGN_OR_RETURN(const auto face_stylizer_file,
resources.GetFile(kFaceStylizerTFLiteName));
SetExternalFile(face_stylizer_file, face_stylizer_external_file, is_copy);
return absl::OkStatus();
}
void ConfigureSplitNormalizedLandmarkListVectorCalculator(
mediapipe::SplitVectorCalculatorOptions* options) {
auto* vector_range = options->add_ranges();
vector_range->set_begin(0);
vector_range->set_end(1);
options->set_element_only(true);
}
void ConfigureLandmarksToDetectionCalculator(
LandmarksToDetectionCalculatorOptions* options) {
// left eye
options->add_selected_landmark_indices(33);
// left eye
options->add_selected_landmark_indices(133);
// right eye
options->add_selected_landmark_indices(263);
// right eye
options->add_selected_landmark_indices(362);
// mouth
options->add_selected_landmark_indices(61);
// mouth
options->add_selected_landmark_indices(291);
}
void ConfigureTensorsToImageCalculator( void ConfigureTensorsToImageCalculator(
const ImageToTensorCalculatorOptions& image_to_tensor_options, const ImageToTensorCalculatorOptions& image_to_tensor_options,
TensorsToImageCalculatorOptions* tensors_to_image_options) { TensorsToImageCalculatorOptions* tensors_to_image_options) {
@ -89,7 +181,7 @@ void ConfigureTensorsToImageCalculator(
} // namespace } // namespace
// A "mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph" performs face // A "mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph" performs face
// stylization. // stylization on the detected face image.
// //
// Inputs: // Inputs:
// IMAGE - Image // IMAGE - Image
@ -114,7 +206,7 @@ void ConfigureTensorsToImageCalculator(
// { // {
// base_options { // base_options {
// model_asset { // model_asset {
// file_name: "face_stylization.tflite" // file_name: "face_stylizer.task"
// } // }
// } // }
// } // }
@ -124,25 +216,94 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
public: public:
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources, ASSIGN_OR_RETURN(
CreateModelResources<FaceStylizerGraphOptions>(sc)); const auto* model_asset_bundle_resources,
CreateModelAssetBundleResources<FaceStylizerGraphOptions>(sc));
// Copies the file content instead of passing the pointer of file in
// memory if the subgraph model resource service is not available.
auto face_stylizer_external_file = absl::make_unique<ExternalFile>();
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
*model_asset_bundle_resources,
sc->MutableOptions<FaceStylizerGraphOptions>(),
face_stylizer_external_file.get(),
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_streams, auto face_landmark_lists,
BuildFaceStylizerGraph( BuildFaceLandmarkerGraph(
sc->Options<FaceStylizerGraphOptions>(), *model_resources, sc->MutableOptions<FaceStylizerGraphOptions>()
->mutable_face_landmarker_graph_options(),
graph[Input<Image>(kImageTag)], graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
ASSIGN_OR_RETURN(
const auto* model_resources,
CreateModelResources(sc, std::move(face_stylizer_external_file)));
ASSIGN_OR_RETURN(
auto output_streams,
BuildFaceStylizerGraph(sc->Options<FaceStylizerGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
face_landmark_lists, graph));
output_streams.stylized_image >> graph[Output<Image>(kStylizedImageTag)]; output_streams.stylized_image >> graph[Output<Image>(kStylizedImageTag)];
output_streams.original_image >> graph[Output<Image>(kImageTag)]; output_streams.original_image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
private: private:
absl::StatusOr<Source<std::vector<NormalizedLandmarkList>>>
BuildFaceLandmarkerGraph(FaceLandmarkerGraphOptions* face_landmarker_options,
Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
auto& landmarker_graph = graph.AddNode(
"mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph");
if (face_landmarker_options->face_detector_graph_options()
.has_num_faces() &&
face_landmarker_options->face_detector_graph_options().num_faces() !=
1) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Face stylizer currently only supports one face.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
face_landmarker_options->mutable_face_detector_graph_options()
->set_num_faces(1);
image_in >> landmarker_graph.In(kImageTag);
norm_rect_in >> landmarker_graph.In(kNormRectTag);
landmarker_graph.GetOptions<FaceLandmarkerGraphOptions>().Swap(
face_landmarker_options);
return landmarker_graph.Out(kNormLandmarksTag)
.Cast<std::vector<NormalizedLandmarkList>>();
}
absl::StatusOr<FaceStylizerOutputStreams> BuildFaceStylizerGraph( absl::StatusOr<FaceStylizerOutputStreams> BuildFaceStylizerGraph(
const FaceStylizerGraphOptions& task_options, const FaceStylizerGraphOptions& task_options,
const ModelResources& model_resources, Source<Image> image_in, const ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) { Source<std::vector<NormalizedLandmarkList>> face_landmark_lists,
Graph& graph) {
auto& split_face_landmark_list =
graph.AddNode("SplitNormalizedLandmarkListVectorCalculator");
ConfigureSplitNormalizedLandmarkListVectorCalculator(
&split_face_landmark_list
.GetOptions<mediapipe::SplitVectorCalculatorOptions>());
face_landmark_lists >> split_face_landmark_list.In("");
auto face_landmarks = split_face_landmark_list.Out("");
auto& landmarks_to_detection =
graph.AddNode("LandmarksToDetectionCalculator");
ConfigureLandmarksToDetectionCalculator(
&landmarks_to_detection
.GetOptions<LandmarksToDetectionCalculatorOptions>());
face_landmarks >> landmarks_to_detection.In(kNormLandmarksTag);
auto face_detection = landmarks_to_detection.Out(kDetectionTag);
auto& get_image_size = graph.AddNode("ImagePropertiesCalculator");
image_in >> get_image_size.In(kImageTag);
auto image_size = get_image_size.Out(kSizeTag);
auto& face_to_rect = graph.AddNode("FaceToRectCalculator");
face_detection >> face_to_rect.In(kDetectionTag);
image_size >> face_to_rect.In(kImageSizeTag);
auto face_rect = face_to_rect.Out(kNormRectTag);
// Adds preprocessing calculators and connects them to the graph input image // Adds preprocessing calculators and connects them to the graph input image
// stream. // stream.
auto& preprocessing = graph.AddNode( auto& preprocessing = graph.AddNode(
@ -163,10 +324,9 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
image_to_tensor_options.set_border_mode( image_to_tensor_options.set_border_mode(
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag); face_rect >> preprocessing.In(kNormRectTag);
auto preprocessed_tensors = preprocessing.Out(kTensorsTag); auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
auto transform_matrix = preprocessing.Out(kMatrixTag); auto transform_matrix = preprocessing.Out(kMatrixTag);
auto image_size = preprocessing.Out(kImageSizeTag);
// Adds inference subgraph and connects its input stream to the output // Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator. // tensors produced by the ImageToTensorCalculator.
@ -206,7 +366,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
// performing extra rotation. // performing extra rotation.
auto& strip_rotation = auto& strip_rotation =
graph.AddNode("mediapipe.tasks.StripRotationCalculator"); graph.AddNode("mediapipe.tasks.StripRotationCalculator");
norm_rect_in >> strip_rotation.In(kNormRectTag); face_rect >> strip_rotation.In(kNormRectTag);
auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag); auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag);
auto& from_image = graph.AddNode("FromImageCalculator"); auto& from_image = graph.AddNode("FromImageCalculator");
image_to_crop >> from_image.In(kImageTag); image_to_crop >> from_image.In(kImageTag);

View File

@ -27,5 +27,6 @@ mediapipe_proto_library(
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_proto",
], ],
) )

View File

@ -20,6 +20,7 @@ package mediapipe.tasks.vision.face_stylizer.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto"; import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.facestylizer.proto"; option java_package = "com.google.mediapipe.tasks.vision.facestylizer.proto";
option java_outer_classname = "FaceStylizerGraphOptionsProto"; option java_outer_classname = "FaceStylizerGraphOptionsProto";
@ -31,4 +32,8 @@ message FaceStylizerGraphOptions {
// Base options for configuring face stylizer, such as specifying the TfLite // Base options for configuring face stylizer, such as specifying the TfLite
// model file with metadata, accelerator options, etc. // model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1; optional core.proto.BaseOptions base_options = 1;
// Options for face landmarker graph.
optional vision.face_landmarker.proto.FaceLandmarkerGraphOptions
face_landmarker_graph_options = 2;
} }

View File

@ -105,6 +105,12 @@ public final class FaceStylizer extends BaseVisionTaskApi {
public FaceStylizerResult convertToTaskResult(List<Packet> packets) public FaceStylizerResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException { throws MediaPipeException {
Packet packet = packets.get(IMAGE_OUT_STREAM_INDEX); Packet packet = packets.get(IMAGE_OUT_STREAM_INDEX);
if (packet.isEmpty()) {
return FaceStylizerResult.create(
Optional.empty(),
BaseVisionTaskApi.generateResultTimestampMs(
stylizerOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX)));
}
int width = PacketGetter.getImageWidth(packet); int width = PacketGetter.getImageWidth(packet);
int height = PacketGetter.getImageHeight(packet); int height = PacketGetter.getImageHeight(packet);
int numChannels = PacketGetter.getImageNumChannels(packet); int numChannels = PacketGetter.getImageNumChannels(packet);
@ -134,7 +140,7 @@ public final class FaceStylizer extends BaseVisionTaskApi {
new ByteBufferImageBuilder(imageBuffer, width, height, imageFormat); new ByteBufferImageBuilder(imageBuffer, width, height, imageFormat);
return FaceStylizerResult.create( return FaceStylizerResult.create(
imageBuilder.build(), Optional.of(imageBuilder.build()),
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
stylizerOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX))); stylizerOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX)));
} }
@ -146,6 +152,10 @@ public final class FaceStylizer extends BaseVisionTaskApi {
.build(); .build();
} }
}); });
// Empty output image packets indicates that no face stylization is applied.
if (stylizerOptions.runningMode() != RunningMode.LIVE_STREAM) {
handler.setHandleTimestampBoundChanges(true);
}
stylizerOptions.resultListener().ifPresent(handler::setResultListener); stylizerOptions.resultListener().ifPresent(handler::setResultListener);
stylizerOptions.errorListener().ifPresent(handler::setErrorListener); stylizerOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner = TaskRunner runner =
@ -210,7 +220,7 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The input image can be of any size, To ensure that the output image has reasonable quality, * <p>The input image can be of any size. To ensure that the output image has reasonable quality,
* the stylized output image size is the smaller of the model output size and the size of the * the stylized output image size is the smaller of the model output size and the size of the
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
* *
@ -271,7 +281,7 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The input image can be of any size, To ensure that the output image has reasonable quality, * <p>The input image can be of any size. To ensure that the output image has reasonable quality,
* the stylized output image size is the smaller of the model output size and the size of the * the stylized output image size is the smaller of the model output size and the size of the
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
* *
@ -336,7 +346,7 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The input image can be of any size, To ensure that the output image has reasonable quality, * <p>The input image can be of any size. To ensure that the output image has reasonable quality,
* the stylized output image size is the smaller of the model output size and the size of the * the stylized output image size is the smaller of the model output size and the size of the
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
* *
@ -404,7 +414,7 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The input image can be of any size, To ensure that the output image has reasonable quality, * <p>The input image can be of any size. To ensure that the output image has reasonable quality,
* the stylized output image size is the smaller of the model output size and the size of the * the stylized output image size is the smaller of the model output size and the size of the
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
* *
@ -465,7 +475,7 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The input image can be of any size, To ensure that the output image has reasonable quality, * <p>The input image can be of any size. To ensure that the output image has reasonable quality,
* the stylized output image size is the smaller of the model output size and the size of the * the stylized output image size is the smaller of the model output size and the size of the
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
* *

View File

@ -17,6 +17,7 @@ package com.google.mediapipe.tasks.vision.facestylizer;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskResult;
import java.util.Optional;
/** Represents the stylized image generated by {@link FaceStylizer}. */ /** Represents the stylized image generated by {@link FaceStylizer}. */
@AutoValue @AutoValue
@ -25,14 +26,15 @@ public abstract class FaceStylizerResult implements TaskResult {
/** /**
* Creates an {@link FaceStylizerResult} instance from a MPImage. * Creates an {@link FaceStylizerResult} instance from a MPImage.
* *
* @param stylizedImage an MPImage representing the stylized face. * @param stylizedImage an {@link Optional} MPImage representing the stylized image of the most
* visible face. Empty if no face is detected on the input image.
* @param timestampMs a timestamp for this result. * @param timestampMs a timestamp for this result.
*/ */
public static FaceStylizerResult create(MPImage stylizedImage, long timestampMs) { public static FaceStylizerResult create(Optional<MPImage> stylizedImage, long timestampMs) {
return new AutoValue_FaceStylizerResult(stylizedImage, timestampMs); return new AutoValue_FaceStylizerResult(stylizedImage, timestampMs);
} }
public abstract MPImage stylizedImage(); public abstract Optional<MPImage> stylizedImage();
@Override @Override
public abstract long timestampMs(); public abstract long timestampMs();

View File

@ -20,7 +20,6 @@ import static org.junit.Assert.assertThrows;
import android.content.res.AssetManager; import android.content.res.AssetManager;
import android.graphics.BitmapFactory; import android.graphics.BitmapFactory;
import android.graphics.RectF; import android.graphics.RectF;
import android.util.Pair;
import androidx.test.core.app.ApplicationProvider; import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4; import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.MediaPipeException;
@ -41,16 +40,10 @@ import org.junit.runners.Suite.SuiteClasses;
@RunWith(Suite.class) @RunWith(Suite.class)
@SuiteClasses({FaceStylizerTest.General.class, FaceStylizerTest.RunningModeTest.class}) @SuiteClasses({FaceStylizerTest.General.class, FaceStylizerTest.RunningModeTest.class})
public class FaceStylizerTest { public class FaceStylizerTest {
private static final String modelFile = "face_stylization_dummy.tflite"; private static final String modelFile = "face_stylizer.task";
private static final String testImage = "portrait.jpg"; private static final String largeFaceTestImage = "portrait.jpg";
private static final int modelImageSize = 512; private static final String smallFaceTestImage = "portrait_small.jpg";
private static final int modelImageSize = 256;
public Pair<Integer, Integer> getRectPixelSize(MPImage originalImage, RectF rect) {
int width = originalImage.getWidth();
int height = originalImage.getHeight();
return new Pair<>(
(int) ((rect.right - rect.left) * width), (int) ((rect.bottom - rect.top) * height));
}
@RunWith(AndroidJUnit4.class) @RunWith(AndroidJUnit4.class)
public static final class General extends FaceStylizerTest { public static final class General extends FaceStylizerTest {
@ -131,17 +124,19 @@ public class FaceStylizerTest {
MediaPipeException.class, MediaPipeException.class,
() -> () ->
faceStylizer.stylizeForVideo( faceStylizer.stylizeForVideo(
getImageFromAsset(testImage), /* timestampsMs= */ 0)); getImageFromAsset(largeFaceTestImage), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception = exception =
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> faceStylizer.stylizeAsync(getImageFromAsset(testImage), /* timestampsMs= */ 0)); () ->
faceStylizer.stylizeAsync(
getImageFromAsset(largeFaceTestImage), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
exception = exception =
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> faceStylizer.stylizeWithResultListener(getImageFromAsset(testImage))); () -> faceStylizer.stylizeWithResultListener(getImageFromAsset(largeFaceTestImage)));
assertThat(exception) assertThat(exception)
.hasMessageThat() .hasMessageThat()
.contains("ResultListener is not set in the FaceStylizerOptions"); .contains("ResultListener is not set in the FaceStylizerOptions");
@ -159,19 +154,22 @@ public class FaceStylizerTest {
FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception = MediaPipeException exception =
assertThrows( assertThrows(
MediaPipeException.class, () -> faceStylizer.stylize(getImageFromAsset(testImage))); MediaPipeException.class,
() -> faceStylizer.stylize(getImageFromAsset(largeFaceTestImage)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception = exception =
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> faceStylizer.stylizeAsync(getImageFromAsset(testImage), /* timestampsMs= */ 0)); () ->
faceStylizer.stylizeAsync(
getImageFromAsset(largeFaceTestImage), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
exception = exception =
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> () ->
faceStylizer.stylizeForVideoWithResultListener( faceStylizer.stylizeForVideoWithResultListener(
getImageFromAsset(testImage), /* timestampsMs= */ 0)); getImageFromAsset(largeFaceTestImage), /* timestampsMs= */ 0));
assertThat(exception) assertThat(exception)
.hasMessageThat() .hasMessageThat()
.contains("ResultListener is not set in the FaceStylizerOptions"); .contains("ResultListener is not set in the FaceStylizerOptions");
@ -191,14 +189,14 @@ public class FaceStylizerTest {
MediaPipeException exception = MediaPipeException exception =
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> faceStylizer.stylizeWithResultListener(getImageFromAsset(testImage))); () -> faceStylizer.stylizeWithResultListener(getImageFromAsset(largeFaceTestImage)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception = exception =
assertThrows( assertThrows(
MediaPipeException.class, MediaPipeException.class,
() -> () ->
faceStylizer.stylizeForVideoWithResultListener( faceStylizer.stylizeForVideoWithResultListener(
getImageFromAsset(testImage), /* timestampsMs= */ 0)); getImageFromAsset(largeFaceTestImage), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
} }
@ -213,18 +211,33 @@ public class FaceStylizerTest {
faceStylizer = faceStylizer =
FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MPImage inputImage = getImageFromAsset(testImage); MPImage inputImage = getImageFromAsset(largeFaceTestImage);
int inputWidth = inputImage.getWidth();
int inputHeight = inputImage.getHeight();
float inputAspectRatio = (float) inputWidth / inputHeight;
FaceStylizerResult actualResult = faceStylizer.stylize(inputImage); FaceStylizerResult actualResult = faceStylizer.stylize(inputImage);
MPImage stylizedImage = actualResult.stylizedImage(); MPImage stylizedImage = actualResult.stylizedImage().get();
assertThat(stylizedImage).isNotNull(); assertThat(stylizedImage).isNotNull();
assertThat(stylizedImage.getWidth()).isEqualTo((int) (modelImageSize * inputAspectRatio)); assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize);
assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize); assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize);
} }
@Test
public void stylizer_succeedsWithSmallImage() throws Exception {
FaceStylizerOptions options =
FaceStylizerOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(modelFile).build())
.setRunningMode(RunningMode.IMAGE)
.build();
faceStylizer =
FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MPImage inputImage = getImageFromAsset(smallFaceTestImage);
FaceStylizerResult actualResult = faceStylizer.stylize(inputImage);
MPImage stylizedImage = actualResult.stylizedImage().get();
assertThat(stylizedImage).isNotNull();
assertThat(stylizedImage.getWidth()).isEqualTo(83);
assertThat(stylizedImage.getHeight()).isEqualTo(83);
}
@Test @Test
public void stylizer_succeedsWithRegionOfInterest() throws Exception { public void stylizer_succeedsWithRegionOfInterest() throws Exception {
FaceStylizerOptions options = FaceStylizerOptions options =
@ -235,7 +248,7 @@ public class FaceStylizerTest {
faceStylizer = faceStylizer =
FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MPImage inputImage = getImageFromAsset(testImage); MPImage inputImage = getImageFromAsset(largeFaceTestImage);
// Region-of-interest around the face. // Region-of-interest around the face.
RectF roi = RectF roi =
@ -244,47 +257,57 @@ public class FaceStylizerTest {
ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
FaceStylizerResult actualResult = faceStylizer.stylize(inputImage, imageProcessingOptions); FaceStylizerResult actualResult = faceStylizer.stylize(inputImage, imageProcessingOptions);
var rectPixelSize = getRectPixelSize(inputImage, roi); MPImage stylizedImage = actualResult.stylizedImage().get();
MPImage stylizedImage = actualResult.stylizedImage();
assertThat(stylizedImage).isNotNull(); assertThat(stylizedImage).isNotNull();
assertThat(stylizedImage.getWidth()).isEqualTo(rectPixelSize.first); assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize);
assertThat(stylizedImage.getHeight()).isEqualTo(rectPixelSize.second); assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize);
}
@Test
public void stylizer_succeedsWithNoFaceDetected() throws Exception {
FaceStylizerOptions options =
FaceStylizerOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(modelFile).build())
.setRunningMode(RunningMode.IMAGE)
.build();
faceStylizer =
FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MPImage inputImage = getImageFromAsset(largeFaceTestImage);
// Region-of-interest that doesn't contain a human face.
RectF roi =
new RectF(/* left= */ 0.1f, /* top= */ 0.1f, /* right= */ 0.2f, /* bottom= */ 0.2f);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
FaceStylizerResult actualResult = faceStylizer.stylize(inputImage, imageProcessingOptions);
assertThat(actualResult.stylizedImage()).isNotNull();
assertThat(actualResult.stylizedImage().isPresent()).isFalse();
} }
@Test @Test
public void stylizer_successWithImageModeWithResultListener() throws Exception { public void stylizer_successWithImageModeWithResultListener() throws Exception {
MPImage inputImage = getImageFromAsset(testImage); MPImage inputImage = getImageFromAsset(largeFaceTestImage);
int inputWidth = inputImage.getWidth();
int inputHeight = inputImage.getHeight();
float inputAspectRatio = (float) inputWidth / inputHeight;
FaceStylizerOptions options = FaceStylizerOptions options =
FaceStylizerOptions.builder() FaceStylizerOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(modelFile).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(modelFile).build())
.setRunningMode(RunningMode.IMAGE) .setRunningMode(RunningMode.IMAGE)
.setResultListener( .setResultListener(
(result, originalImage) -> { (result, originalImage) -> {
assertThat(originalImage).isEqualTo(inputImage); MPImage stylizedImage = result.stylizedImage().get();
MPImage stylizedImage = result.stylizedImage();
assertThat(stylizedImage).isNotNull(); assertThat(stylizedImage).isNotNull();
assertThat(stylizedImage.getWidth()) assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize);
.isEqualTo(modelImageSize * inputAspectRatio);
assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize); assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize);
}) })
.build(); .build();
faceStylizer = faceStylizer =
FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
faceStylizer.stylizeWithResultListener(getImageFromAsset(testImage)); faceStylizer.stylizeWithResultListener(getImageFromAsset(largeFaceTestImage));
} }
@Test @Test
public void stylizer_successWithVideoMode() throws Exception { public void stylizer_successWithVideoMode() throws Exception {
MPImage inputImage = getImageFromAsset(testImage); MPImage inputImage = getImageFromAsset(largeFaceTestImage);
int inputWidth = inputImage.getWidth();
int inputHeight = inputImage.getHeight();
float inputAspectRatio = (float) inputWidth / inputHeight;
FaceStylizerOptions options = FaceStylizerOptions options =
FaceStylizerOptions.builder() FaceStylizerOptions.builder()
@ -295,21 +318,19 @@ public class FaceStylizerTest {
FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); FaceStylizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
FaceStylizerResult actualResult = FaceStylizerResult actualResult =
faceStylizer.stylizeForVideo(getImageFromAsset(testImage), /* timestampsMs= */ i); faceStylizer.stylizeForVideo(
getImageFromAsset(largeFaceTestImage), /* timestampsMs= */ i);
MPImage stylizedImage = actualResult.stylizedImage(); MPImage stylizedImage = actualResult.stylizedImage().get();
assertThat(stylizedImage).isNotNull(); assertThat(stylizedImage).isNotNull();
assertThat(stylizedImage.getWidth()).isEqualTo((int) (modelImageSize * inputAspectRatio)); assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize);
assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize); assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize);
} }
} }
@Test @Test
public void stylizer_successWithVideoModeWithResultListener() throws Exception { public void stylizer_successWithVideoModeWithResultListener() throws Exception {
MPImage inputImage = getImageFromAsset(testImage); MPImage inputImage = getImageFromAsset(largeFaceTestImage);
int inputWidth = inputImage.getWidth();
int inputHeight = inputImage.getHeight();
float inputAspectRatio = (float) inputWidth / inputHeight;
FaceStylizerOptions options = FaceStylizerOptions options =
FaceStylizerOptions.builder() FaceStylizerOptions.builder()
@ -317,12 +338,9 @@ public class FaceStylizerTest {
.setRunningMode(RunningMode.VIDEO) .setRunningMode(RunningMode.VIDEO)
.setResultListener( .setResultListener(
(result, originalImage) -> { (result, originalImage) -> {
assertThat(originalImage).isEqualTo(inputImage); MPImage stylizedImage = result.stylizedImage().get();
MPImage stylizedImage = result.stylizedImage();
assertThat(stylizedImage).isNotNull(); assertThat(stylizedImage).isNotNull();
assertThat(stylizedImage.getWidth()) assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize);
.isEqualTo((int) (modelImageSize * inputAspectRatio));
assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize); assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize);
}) })
.build(); .build();
@ -335,10 +353,7 @@ public class FaceStylizerTest {
@Test @Test
public void stylizer_successWithLiveStreamMode() throws Exception { public void stylizer_successWithLiveStreamMode() throws Exception {
MPImage inputImage = getImageFromAsset(testImage); MPImage inputImage = getImageFromAsset(largeFaceTestImage);
int inputWidth = inputImage.getWidth();
int inputHeight = inputImage.getHeight();
float inputAspectRatio = (float) inputWidth / inputHeight;
FaceStylizerOptions options = FaceStylizerOptions options =
FaceStylizerOptions.builder() FaceStylizerOptions.builder()
@ -346,10 +361,9 @@ public class FaceStylizerTest {
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(result, originalImage) -> { (result, originalImage) -> {
MPImage stylizedImage = result.stylizedImage(); MPImage stylizedImage = result.stylizedImage().get();
assertThat(stylizedImage).isNotNull(); assertThat(stylizedImage).isNotNull();
assertThat(stylizedImage.getWidth()) assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize);
.isEqualTo((int) (modelImageSize * inputAspectRatio));
assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize); assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize);
}) })
.build(); .build();
@ -363,7 +377,7 @@ public class FaceStylizerTest {
@Test @Test
public void stylizer_failsWithOutOfOrderInputTimestamps() throws Exception { public void stylizer_failsWithOutOfOrderInputTimestamps() throws Exception {
MPImage image = getImageFromAsset(testImage); MPImage image = getImageFromAsset(largeFaceTestImage);
FaceStylizerOptions options = FaceStylizerOptions options =
FaceStylizerOptions.builder() FaceStylizerOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(modelFile).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(modelFile).build())

View File

@ -128,6 +128,14 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
return return
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
stylized_image_packet = output_packets[_STYLIZED_IMAGE_NAME] stylized_image_packet = output_packets[_STYLIZED_IMAGE_NAME]
if stylized_image_packet.is_empty():
options.result_callback(
None,
image,
stylized_image_packet.timestamp.value
// _MICRO_SECONDS_PER_MILLISECOND,
)
stylized_image = packet_getter.get_image(stylized_image_packet) stylized_image = packet_getter.get_image(stylized_image_packet)
options.result_callback( options.result_callback(
@ -177,7 +185,8 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
image_processing_options: Options for image processing. image_processing_options: Options for image processing.
Returns: Returns:
The stylized image. The stylized image of the most visible face. None if no face is detected
on the input image.
Raises: Raises:
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
@ -191,6 +200,8 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
), ),
}) })
if output_packets[_STYLIZED_IMAGE_NAME].is_empty():
return None
return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME]) return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME])
def stylize_for_video( def stylize_for_video(
@ -216,7 +227,8 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
image_processing_options: Options for image processing. image_processing_options: Options for image processing.
Returns: Returns:
The stylized image. The stylized image of the most visible face. None if no face is detected
on the input image.
Raises: Raises:
ValueError: If any of the input arguments is invalid. ValueError: If any of the input arguments is invalid.
@ -232,6 +244,8 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
if output_packets[_STYLIZED_IMAGE_NAME].is_empty():
return None
return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME]) return packet_getter.get_image(output_packets[_STYLIZED_IMAGE_NAME])
def stylize_async( def stylize_async(
@ -257,7 +271,8 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
the `region_of_interest` specified in `image_processing_options`. the `region_of_interest` specified in `image_processing_options`.
The `result_callback` provides: The `result_callback` provides:
- The stylized image. - The stylized image of the most visible face. None if no face is detected
on the input image.
- The input image that the face stylizer runs on. - The input image that the face stylizer runs on.
- The input timestamp in milliseconds. - The input timestamp in milliseconds.