Allow FaceStylizerGraph to miss base options.

Fix the color issue when the graph is running on gpu and "face alignment only" mode.

PiperOrigin-RevId: 534912498
This commit is contained in:
Jiuqiang Tang 2023-05-24 11:12:20 -07:00 committed by Copybara-Service
parent acfaf3f1b6
commit 7facc925ba

View File

@ -231,18 +231,26 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
SubgraphContext* sc) override { SubgraphContext* sc) override {
bool output_stylized = HasOutput(sc->OriginalNode(), kStylizedImageTag); bool output_stylized = HasOutput(sc->OriginalNode(), kStylizedImageTag);
bool output_alignment = HasOutput(sc->OriginalNode(), kFaceAlignmentTag); bool output_alignment = HasOutput(sc->OriginalNode(), kFaceAlignmentTag);
ASSIGN_OR_RETURN(
const auto* model_asset_bundle_resources,
CreateModelAssetBundleResources<FaceStylizerGraphOptions>(sc));
// Copies the file content instead of passing the pointer of file in
// memory if the subgraph model resource service is not available.
auto face_stylizer_external_file = absl::make_unique<ExternalFile>(); auto face_stylizer_external_file = absl::make_unique<ExternalFile>();
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( if (sc->Options<FaceStylizerGraphOptions>().has_base_options()) {
*model_asset_bundle_resources, ASSIGN_OR_RETURN(
sc->MutableOptions<FaceStylizerGraphOptions>(), const auto* model_asset_bundle_resources,
output_stylized ? face_stylizer_external_file.get() : nullptr, CreateModelAssetBundleResources<FaceStylizerGraphOptions>(sc));
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) // Copies the file content instead of passing the pointer of file in
.IsAvailable())); // memory if the subgraph model resource service is not available.
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
*model_asset_bundle_resources,
sc->MutableOptions<FaceStylizerGraphOptions>(),
output_stylized ? face_stylizer_external_file.get() : nullptr,
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
} else if (output_stylized) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Face stylizer must specify its base options when the "
"\"STYLIZED_IMAGE\" output stream is connected.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto face_landmark_lists, auto face_landmark_lists,
@ -347,7 +355,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator"); auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator");
auto& image_to_tensor_options = auto& image_to_tensor_options =
image_to_tensor.GetOptions<ImageToTensorCalculatorOptions>(); image_to_tensor.GetOptions<ImageToTensorCalculatorOptions>();
image_to_tensor_options.mutable_output_tensor_float_range()->set_min(-1); image_to_tensor_options.mutable_output_tensor_float_range()->set_min(0);
image_to_tensor_options.mutable_output_tensor_float_range()->set_max(1); image_to_tensor_options.mutable_output_tensor_float_range()->set_max(1);
image_to_tensor_options.set_output_tensor_width(kFaceAlignmentOutputSize); image_to_tensor_options.set_output_tensor_width(kFaceAlignmentOutputSize);
image_to_tensor_options.set_output_tensor_height( image_to_tensor_options.set_output_tensor_height(
@ -363,7 +371,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
graph.AddNode("mediapipe.tasks.TensorsToImageCalculator"); graph.AddNode("mediapipe.tasks.TensorsToImageCalculator");
auto& tensors_to_image_options = auto& tensors_to_image_options =
tensors_to_image.GetOptions<TensorsToImageCalculatorOptions>(); tensors_to_image.GetOptions<TensorsToImageCalculatorOptions>();
tensors_to_image_options.mutable_input_tensor_float_range()->set_min(-1); tensors_to_image_options.mutable_input_tensor_float_range()->set_min(0);
tensors_to_image_options.mutable_input_tensor_float_range()->set_max(1); tensors_to_image_options.mutable_input_tensor_float_range()->set_max(1);
face_alignment_image >> tensors_to_image.In(kTensorsTag); face_alignment_image >> tensors_to_image.In(kTensorsTag);
face_alignment = tensors_to_image.Out(kImageTag).Cast<Image>(); face_alignment = tensors_to_image.Out(kImageTag).Cast<Image>();