From e18e749e3ef52e057860e6498e8c4c51f5bbccfc Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 29 Aug 2023 15:02:23 -0700 Subject: [PATCH] Internal update PiperOrigin-RevId: 561148365 --- .../tasks/cc/vision/image_generator/BUILD | 136 ++++ .../conditioned_image_graph.cc | 458 ++++++++++++ .../conditioned_image_graph_test.cc | 147 ++++ .../cc/vision/image_generator/diffuser/BUILD | 70 ++ .../image_generator/diffuser/diffuser_gpu.h | 88 +++ .../diffusion_plugins_output_calculator.cc | 67 ++ .../stable_diffusion_iterate_calculator.cc | 278 ++++++++ .../stable_diffusion_iterate_calculator.proto | 84 +++ .../vision/image_generator/image_generator.cc | 397 +++++++++++ .../vision/image_generator/image_generator.h | 157 +++++ .../image_generator/image_generator_graph.cc | 361 ++++++++++ .../image_generator/image_generator_result.h | 41 ++ .../cc/vision/image_generator/proto/BUILD | 52 ++ .../conditioned_image_graph_options.proto | 66 ++ .../proto/control_plugin_graph_options.proto | 34 + .../proto/image_generator_graph_options.proto | 35 + .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 50 ++ .../com/google/mediapipe/tasks/vision/BUILD | 5 +- .../vision/imagegenerator/AndroidManifest.xml | 8 + .../tasks/vision/imagegenerator/BUILD | 84 +++ .../vision/imagegenerator/ImageGenerator.java | 660 ++++++++++++++++++ .../imagegenerator/ImageGeneratorResult.java | 44 ++ 22 files changed, 3321 insertions(+), 1 deletion(-) create mode 100644 mediapipe/tasks/cc/vision/image_generator/BUILD create mode 100644 mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph.cc create mode 100644 mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph_test.cc create mode 100644 mediapipe/tasks/cc/vision/image_generator/diffuser/BUILD create mode 100644 mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h create mode 100644 mediapipe/tasks/cc/vision/image_generator/diffuser/diffusion_plugins_output_calculator.cc create mode 100644 mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc create mode 100644 mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto create mode 100644 mediapipe/tasks/cc/vision/image_generator/image_generator.cc create mode 100644 mediapipe/tasks/cc/vision/image_generator/image_generator.h create mode 100644 mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc create mode 100644 mediapipe/tasks/cc/vision/image_generator/image_generator_result.h create mode 100644 mediapipe/tasks/cc/vision/image_generator/proto/BUILD create mode 100644 mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto create mode 100644 mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto create mode 100644 mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/BUILD create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGeneratorResult.java diff --git a/mediapipe/tasks/cc/vision/image_generator/BUILD b/mediapipe/tasks/cc/vision/image_generator/BUILD new file mode 100644 index 000000000..71b8230ae --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/BUILD @@ -0,0 
+1,136 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +cc_library( + name = "conditioned_image_graph", + srcs = ["conditioned_image_graph.cc"], + deps = [ + "//mediapipe/calculators/core:get_vector_item_calculator", + "//mediapipe/calculators/core:get_vector_item_calculator_cc_proto", + "//mediapipe/calculators/util:annotation_overlay_calculator", + "//mediapipe/calculators/util:flat_color_image_calculator", + "//mediapipe/calculators/util:flat_color_image_calculator_cc_proto", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator", + "//mediapipe/calculators/util:landmarks_to_render_data_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgcodecs", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarks_connections", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:image_frame_util", + "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_library( + name = "image_generator_graph", + srcs = ["image_generator_graph.cc"], + deps = [ + ":conditioned_image_graph", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/image:image_transformation_calculator_cc_proto", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/calculators/util:from_image_calculator", + "//mediapipe/calculators/util:to_image_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:stream_handler_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/deps:file_path", + 
"//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:switch_container", + "//mediapipe/framework/tool:switch_container_cc_proto", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/diffuser:diffusion_plugins_output_calculator", + "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator", + "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_cc_proto", + "//mediapipe/util:graph_builder_utils", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_library( + name = "image_generator_result", + hdrs = ["image_generator_result.h"], + deps = ["//mediapipe/framework/formats:image"], +) + +cc_library( + name = "image_generator", + srcs = ["image_generator.cc"], + hdrs = ["image_generator.h"], + deps = [ + ":image_generator_graph", + ":image_generator_result", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) diff --git a/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph.cc b/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph.cc new file mode 100644 index 000000000..c85fe981c --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph.cc @@ -0,0 +1,458 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/core/get_vector_item_calculator.h" +#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h" +#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.h" +#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/image_frame_util.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_generator { + +namespace internal { + +// Helper postprocessing calculator for depth condition type to scale raw depth +// inference result to 0-255 uint8. +class DepthImagePostprocessingCalculator : public api2::Node { + public: + static constexpr api2::Input kImageIn{"IMAGE"}; + static constexpr api2::Output kImageOut{"IMAGE"}; + + MEDIAPIPE_NODE_CONTRACT(kImageIn, kImageOut); + + absl::Status Process(CalculatorContext* cc) final { + if (kImageIn(cc).IsEmpty()) { + return absl::OkStatus(); + } + Image raw_depth_image = kImageIn(cc).Get(); + cv::Mat raw_depth_mat = mediapipe::formats::MatView( + raw_depth_image.GetImageFrameSharedPtr().get()); + cv::Mat depth_mat; + cv::normalize(raw_depth_mat, depth_mat, 255, 0, cv::NORM_MINMAX); + depth_mat.convertTo(depth_mat, CV_8UC3, 1, 0); + cv::cvtColor(depth_mat, depth_mat, cv::COLOR_GRAY2RGB); + // Acquires the cv::Mat data and assign to the image frame. + ImageFrameSharedPtr depth_image_frame_ptr = std::make_shared( + mediapipe::ImageFormat::SRGB, depth_mat.cols, depth_mat.rows, + depth_mat.step, depth_mat.data, + [depth_mat](uint8_t[]) { depth_mat.~Mat(); }); + Image depth_image(depth_image_frame_ptr); + kImageOut(cc).Send(depth_image); + return absl::OkStatus(); + } +}; + +// NOLINTBEGIN: Node registration doesn't work when part of calculator name is +// moved to next line. 
+// clang-format off +MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::vision::image_generator::internal::DepthImagePostprocessingCalculator); +// clang-format on +// NOLINTEND + +// Calculator to detect edges in the image with OpenCV Canny edge detection. +class CannyEdgeCalculator : public api2::Node { + public: + static constexpr api2::Input kImageIn{"IMAGE"}; + static constexpr api2::Output kImageOut{"IMAGE"}; + + MEDIAPIPE_NODE_CONTRACT(kImageIn, kImageOut); + + absl::Status Process(CalculatorContext* cc) final { + if (kImageIn(cc).IsEmpty()) { + return absl::OkStatus(); + } + Image input_image = kImageIn(cc).Get(); + cv::Mat input_image_mat = + mediapipe::formats::MatView(input_image.GetImageFrameSharedPtr().get()); + const auto& options = cc->Options< + proto::ConditionedImageGraphOptions::EdgeConditionTypeOptions>(); + cv::Mat lumincance; + cv::cvtColor(input_image_mat, lumincance, cv::COLOR_RGB2GRAY); + cv::Mat edges_mat; + cv::Canny(lumincance, edges_mat, options.threshold_1(), + options.threshold_2(), options.aperture_size(), + options.l2_gradient()); + cv::normalize(edges_mat, edges_mat, 255, 0, cv::NORM_MINMAX); + edges_mat.convertTo(edges_mat, CV_8UC3, 1, 0); + cv::cvtColor(edges_mat, edges_mat, cv::COLOR_GRAY2RGB); + // Acquires the cv::Mat data and assign to the image frame. + ImageFrameSharedPtr edges_image_frame_ptr = std::make_shared( + mediapipe::ImageFormat::SRGB, edges_mat.cols, edges_mat.rows, + edges_mat.step, edges_mat.data, + [edges_mat](uint8_t[]) { edges_mat.~Mat(); }); + Image edges_image(edges_image_frame_ptr); + kImageOut(cc).Send(edges_image); + return absl::OkStatus(); + } +}; + +// NOLINTBEGIN: Node registration doesn't work when part of calculator name is +// moved to next line. +// clang-format off +MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::vision::image_generator::internal::CannyEdgeCalculator); +// clang-format on +// NOLINTEND + +} // namespace internal + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; + +constexpr absl::string_view kImageTag = "IMAGE"; +constexpr absl::string_view kUImageTag = "UIMAGE"; +constexpr absl::string_view kNormLandmarksTag = "NORM_LANDMARKS"; +constexpr absl::string_view kVectorTag = "VECTOR"; +constexpr absl::string_view kItemTag = "ITEM"; +constexpr absl::string_view kRenderDataTag = "RENDER_DATA"; +constexpr absl::string_view kConfidenceMaskTag = "CONFIDENCE_MASK:0"; + +enum ColorType { + WHITE = 0, + GREEN = 1, + RED = 2, + BLACK = 3, + BLUE = 4, +}; + +mediapipe::Color GetColor(ColorType color_type) { + mediapipe::Color color; + switch (color_type) { + case WHITE: + color.set_b(255); + color.set_g(255); + color.set_r(255); + break; + case GREEN: + color.set_b(0); + color.set_g(255); + color.set_r(0); + break; + case RED: + color.set_b(0); + color.set_g(0); + color.set_r(255); + break; + case BLACK: + color.set_b(0); + color.set_g(0); + color.set_r(0); + break; + case BLUE: + color.set_b(255); + color.set_g(0); + color.set_r(0); + break; + } + return color; +} + +// Get LandmarksToRenderDataCalculatorOptions for rendering face landmarks +// connections. 
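+// Only the listed connections are drawn: landmark points and depth
+// visualization are disabled, so the guidance image contains thin contour
+// lines in the requested color.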
+mediapipe::LandmarksToRenderDataCalculatorOptions +GetFaceLandmarksRenderDataOptions( + absl::Span> connections, ColorType color_type) { + mediapipe::LandmarksToRenderDataCalculatorOptions render_options; + render_options.set_thickness(1); + render_options.set_visualize_landmark_depth(false); + render_options.set_render_landmarks(false); + *render_options.mutable_connection_color() = GetColor(color_type); + for (const auto& connection : connections) { + render_options.add_landmark_connections(connection[0]); + render_options.add_landmark_connections(connection[1]); + } + return render_options; +} + +Source GetFaceLandmarksRenderData( + Source face_landmarks, + const mediapipe::LandmarksToRenderDataCalculatorOptions& + landmarks_to_render_data_options, + Graph& graph) { + auto& landmarks_to_render_data = + graph.AddNode("LandmarksToRenderDataCalculator"); + landmarks_to_render_data + .GetOptions() + .CopyFrom(landmarks_to_render_data_options); + face_landmarks >> landmarks_to_render_data.In(kNormLandmarksTag); + return landmarks_to_render_data.Out(kRenderDataTag) + .Cast(); +} + +// Add FaceLandmarkerGraph to detect the face landmarks in the given face image, +// and generate a face mesh guidance image for the diffusion plugin model. +absl::StatusOr> GetFaceLandmarksImage( + Source face_image, + const proto::ConditionedImageGraphOptions::FaceConditionTypeOptions& + face_condition_type_options, + Graph& graph) { + if (face_condition_type_options.face_landmarker_graph_options() + .face_detector_graph_options() + .num_faces() != 1) { + return absl::InvalidArgumentError( + "Only supports face landmarks of a single face as the guidance image."); + } + + // Detect face landmarks. + auto& face_landmarker_graph = graph.AddNode( + "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph"); + face_landmarker_graph + .GetOptions() + .CopyFrom(face_condition_type_options.face_landmarker_graph_options()); + face_image >> face_landmarker_graph.In(kImageTag); + auto face_landmarks_lists = + face_landmarker_graph.Out(kNormLandmarksTag) + .Cast>(); + + // Get the single face landmarks. + auto& get_vector_item = + graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator"); + get_vector_item.GetOptions() + .set_item_index(0); + face_landmarks_lists >> get_vector_item.In(kVectorTag); + auto single_face_landmarks = + get_vector_item.Out(kItemTag).Cast(); + + // Convert face landmarks to render data. 
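+  // Each feature group below becomes its own render data stream, colored per
+  // GetColor: face oval and lips in white, left eye/brow/iris in green, and
+  // right eye/brow/iris in blue.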
+ auto face_oval = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksFaceOval + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksFaceOval + .size()), + ColorType::WHITE), + graph); + auto lips = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLips + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLips + .size()), + ColorType::WHITE), + graph); + auto left_eye = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftEye + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftEye + .size()), + ColorType::GREEN), + graph); + auto left_eye_brow = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span>( + face_landmarker::FaceLandmarksConnections:: + kFaceLandmarksLeftEyeBrow.data(), + face_landmarker::FaceLandmarksConnections:: + kFaceLandmarksLeftEyeBrow.size()), + ColorType::GREEN), + graph); + auto left_iris = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftIris + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftIris + .size()), + ColorType::GREEN), + graph); + + auto right_eye = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightEye + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightEye + .size()), + ColorType::BLUE), + graph); + auto right_eye_brow = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span>( + face_landmarker::FaceLandmarksConnections:: + kFaceLandmarksRightEyeBrow.data(), + face_landmarker::FaceLandmarksConnections:: + kFaceLandmarksRightEyeBrow.size()), + ColorType::BLUE), + graph); + auto right_iris = GetFaceLandmarksRenderData( + single_face_landmarks, + GetFaceLandmarksRenderDataOptions( + absl::Span>( + face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightIris + .data(), + face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightIris + .size()), + ColorType::BLUE), + graph); + + // Create a black canvas image with same size as face image. + auto& flat_color = graph.AddNode("FlatColorImageCalculator"); + flat_color.GetOptions() + .mutable_color() + ->set_r(0); + face_image >> flat_color.In(kImageTag); + auto blank_canvas = flat_color.Out(kImageTag); + + // Draw render data on the canvas image. 
+ auto& annotation_overlay = graph.AddNode("AnnotationOverlayCalculator"); + blank_canvas >> annotation_overlay.In(kUImageTag); + face_oval >> annotation_overlay.In(0); + lips >> annotation_overlay.In(1); + left_eye >> annotation_overlay.In(2); + left_eye_brow >> annotation_overlay.In(3); + left_iris >> annotation_overlay.In(4); + right_eye >> annotation_overlay.In(5); + right_eye_brow >> annotation_overlay.In(6); + right_iris >> annotation_overlay.In(7); + return annotation_overlay.Out(kUImageTag).Cast(); +} + +absl::StatusOr> GetDepthImage( + Source image, + const image_generator::proto::ConditionedImageGraphOptions:: + DepthConditionTypeOptions& depth_condition_type_options, + Graph& graph) { + auto& image_segmenter_graph = graph.AddNode( + "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"); + image_segmenter_graph + .GetOptions() + .CopyFrom(depth_condition_type_options.image_segmenter_graph_options()); + image >> image_segmenter_graph.In(kImageTag); + auto raw_depth_image = image_segmenter_graph.Out(kConfidenceMaskTag); + + auto& depth_postprocessing = graph.AddNode( + "mediapipe.tasks.vision.image_generator.internal." + "DepthImagePostprocessingCalculator"); + raw_depth_image >> depth_postprocessing.In(kImageTag); + return depth_postprocessing.Out(kImageTag).Cast(); +} + +absl::StatusOr> GetEdgeImage( + Source image, + const image_generator::proto::ConditionedImageGraphOptions:: + EdgeConditionTypeOptions& edge_condition_type_options, + Graph& graph) { + auto& edge_detector = graph.AddNode( + "mediapipe.tasks.vision.image_generator.internal." + "CannyEdgeCalculator"); + edge_detector + .GetOptions< + proto::ConditionedImageGraphOptions::EdgeConditionTypeOptions>() + .CopyFrom(edge_condition_type_options); + image >> edge_detector.In(kImageTag); + return edge_detector.Out(kImageTag).Cast(); +} + +} // namespace + +// A mediapipe.tasks.vision.image_generator.ConditionedImageGraph converts the +// input image to an image of condition type. The output image can be used as +// input for the diffusion model with control plugin. +// Inputs: +// IMAGE - Image +// Conditioned image to generate the image for diffusion plugin model. +// +// Outputs: +// IMAGE - Image +// The guidance image used as input for the diffusion plugin model. +class ConditionedImageGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + auto& graph_options = + *sc->MutableOptions(); + Source conditioned_image = graph.In(kImageTag).Cast(); + // Configure the guidance graph and get the guidance image if guidance graph + // options is set. 
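+    // Depending on the selected condition type, the guidance image is a
+    // rendered face mesh (face), a normalized depth map (depth), or a Canny
+    // edge map (edge).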
+ switch (graph_options.condition_type_options_case()) { + case proto::ConditionedImageGraphOptions::CONDITION_TYPE_OPTIONS_NOT_SET: + return absl::InvalidArgumentError( + "Conditioned type options is not set."); + break; + case proto::ConditionedImageGraphOptions::kFaceConditionTypeOptions: { + ASSIGN_OR_RETURN( + auto face_landmarks_image, + GetFaceLandmarksImage(conditioned_image, + graph_options.face_condition_type_options(), + graph)); + face_landmarks_image >> graph.Out(kImageTag); + } break; + case proto::ConditionedImageGraphOptions::kDepthConditionTypeOptions: { + ASSIGN_OR_RETURN( + auto depth_image, + GetDepthImage(conditioned_image, + graph_options.depth_condition_type_options(), graph)); + depth_image >> graph.Out(kImageTag); + } break; + case proto::ConditionedImageGraphOptions::kEdgeConditionTypeOptions: { + ASSIGN_OR_RETURN( + auto edges_image, + GetEdgeImage(conditioned_image, + graph_options.edge_condition_type_options(), graph)); + edges_image >> graph.Out(kImageTag); + } break; + } + return graph.GetConfig(); + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::image_generator::ConditionedImageGraph); + +} // namespace image_generator +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph_test.cc b/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph_test.cc new file mode 100644 index 000000000..c67ae2fe9 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_generator { + +namespace { + +using ::mediapipe::Image; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::DecodeImageFromFile; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kFaceLandmarkerModel[] = "face_landmarker_v2.task"; +constexpr char kDepthModel[] = + "mobilenetsweep_dptrigmqn384_unit_384_384_fp16quant_fp32input_opt.tflite"; +constexpr char kPortraitImage[] = "portrait.jpg"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageInStream[] = "image_in"; +constexpr char kImageOutStream[] = "image_out"; + +// Helper function to create a ConditionedImageGraphTaskRunner TaskRunner. 
+absl::StatusOr> +CreateConditionedImageGraphTaskRunner( + std::unique_ptr options) { + Graph graph; + auto& conditioned_image_graph = graph.AddNode( + "mediapipe.tasks.vision.image_generator.ConditionedImageGraph"); + conditioned_image_graph.GetOptions() + .Swap(options.get()); + graph.In(kImageTag).Cast().SetName(kImageInStream) >> + conditioned_image_graph.In(kImageTag); + conditioned_image_graph.Out(kImageTag).SetName(kImageOutStream) >> + graph.Out(kImageTag).Cast(); + return core::TaskRunner::Create( + graph.GetConfig(), + absl::make_unique()); +} + +TEST(ConditionedImageGraphTest, SucceedsFaceLandmarkerConditionType) { + auto options = std::make_unique(); + options->mutable_face_condition_type_options() + ->mutable_face_landmarker_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name( + file::JoinPath("./", kTestDataDirectory, kFaceLandmarkerModel)); + options->mutable_face_condition_type_options() + ->mutable_face_landmarker_graph_options() + ->mutable_face_detector_graph_options() + ->set_num_faces(1); + MP_ASSERT_OK_AND_ASSIGN( + auto runner, CreateConditionedImageGraphTaskRunner(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory, + kPortraitImage))); + MP_ASSERT_OK_AND_ASSIGN( + auto output_packets, + runner->Process({{kImageInStream, MakePacket(std::move(image))}})); + const auto& output_image = output_packets[kImageOutStream].Get(); + MP_EXPECT_OK(SavePngTestOutput(*output_image.GetImageFrameSharedPtr(), + "face_landmarks_image")); +} + +TEST(ConditionedImageGraphTest, SucceedsDepthConditionType) { + auto options = std::make_unique(); + options->mutable_depth_condition_type_options() + ->mutable_image_segmenter_graph_options() + ->mutable_base_options() + ->mutable_model_asset() + ->set_file_name(file::JoinPath("./", kTestDataDirectory, kDepthModel)); + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory, + kPortraitImage))); + MP_ASSERT_OK_AND_ASSIGN( + auto runner, CreateConditionedImageGraphTaskRunner(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + auto output_packets, + runner->Process({{kImageInStream, MakePacket(std::move(image))}})); + const auto& output_image = output_packets[kImageOutStream].Get(); + MP_EXPECT_OK( + SavePngTestOutput(*output_image.GetImageFrameSharedPtr(), "depth_image")); +} + +TEST(ConditionedImageGraphTest, SucceedsEdgeConditionType) { + auto options = std::make_unique(); + auto edge_condition_type_options = + options->mutable_edge_condition_type_options(); + edge_condition_type_options->set_threshold_1(100); + edge_condition_type_options->set_threshold_2(200); + edge_condition_type_options->set_aperture_size(3); + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory, + kPortraitImage))); + MP_ASSERT_OK_AND_ASSIGN( + auto runner, CreateConditionedImageGraphTaskRunner(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + auto output_packets, + runner->Process({{kImageInStream, MakePacket(std::move(image))}})); + const auto& output_image = output_packets[kImageOutStream].Get(); + MP_EXPECT_OK( + SavePngTestOutput(*output_image.GetImageFrameSharedPtr(), "edges_image")); +} + +} // namespace +} // namespace image_generator +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/BUILD b/mediapipe/tasks/cc/vision/image_generator/diffuser/BUILD new file mode 100644 index 
000000000..e4fd9b5bc --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/BUILD @@ -0,0 +1,70 @@ +# Copyright 2022 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +cc_library( + name = "diffuser_gpu_header", + hdrs = ["diffuser_gpu.h"], + visibility = [ + "//mediapipe/tasks/cc/vision/image_generator/diffuser:__pkg__", + ], +) + +mediapipe_proto_library( + name = "stable_diffusion_iterate_calculator_proto", + srcs = ["stable_diffusion_iterate_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "stable_diffusion_iterate_calculator", + srcs = ["stable_diffusion_iterate_calculator.cc"], + deps = [ + ":diffuser_gpu_header", + ":stable_diffusion_iterate_calculator_cc_proto", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/deps:file_helpers", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + ], + alwayslink = 1, +) + +cc_library( + name = "diffusion_plugins_output_calculator", + srcs = ["diffusion_plugins_output_calculator.cc"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h b/mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h new file mode 100644 index 000000000..522f0430c --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h @@ -0,0 +1,88 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_ + +#include +#include + +#ifndef DG_EXPORT +#define DG_EXPORT __attribute__((visibility("default"))) +#endif // DG_EXPORT + +#ifdef __cplusplus +extern "C" { +#endif + +enum DiffuserModelType { + kDiffuserModelTypeSd1, + kDiffuserModelTypeGldm, + kDiffuserModelTypeDistilledGldm, + kDiffuserModelTypeSd2Base, + kDiffuserModelTypeTigo, +}; + +enum DiffuserPriorityHint { + kDiffuserPriorityHintHigh, + kDiffuserPriorityHintNormal, + kDiffuserPriorityHintLow, +}; + +enum DiffuserPerformanceHint { + kDiffuserPerformanceHintHigh, + kDiffuserPerformanceHintNormal, + kDiffuserPerformanceHintLow, +}; + +typedef struct { + DiffuserPriorityHint priority_hint; + DiffuserPerformanceHint performance_hint; +} DiffuserEnvironmentOptions; + +typedef struct { + DiffuserModelType model_type; + char model_dir[PATH_MAX]; + char lora_dir[PATH_MAX]; + const void* lora_weights_layer_mapping; + int lora_rank; + int seed; + int image_width; + int image_height; + int run_unet_with_plugins; + float plugins_strength; + DiffuserEnvironmentOptions env_options; +} DiffuserConfig; + +typedef struct { + void* diffuser; +} DiffuserContext; + +typedef struct { + int shape[4]; + const float* data; +} DiffuserPluginTensor; + +DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*); // NOLINT +DG_EXPORT int DiffuserReset(DiffuserContext*, // NOLINT + const char*, int, int, const void*); +DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int); // NOLINT +DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*); // NOLINT +DG_EXPORT void DiffuserDelete(DiffuserContext*); // NOLINT + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_ diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/diffusion_plugins_output_calculator.cc b/mediapipe/tasks/cc/vision/image_generator/diffuser/diffusion_plugins_output_calculator.cc new file mode 100644 index 000000000..98fefe8c5 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/diffusion_plugins_output_calculator.cc @@ -0,0 +1,67 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe { +namespace api2 { + +// In iteration mode, output the image guidance tensors at the current timestamp +// and advance the output stream timestamp bound by the number of steps. +// Otherwise, output the image guidance tensors at the current timestamp only. 
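+//
+// For illustration only (not part of this change; stream names are assumed),
+// a node using this calculator could be wired as:
+//   node {
+//     calculator: "DiffusionPluginsOutputCalculator"
+//     input_stream: "TENSORS:plugin_tensors"
+//     input_stream: "STEPS:steps"
+//     input_stream: "ITERATION:iteration"
+//     output_stream: "TENSORS:stepped_plugin_tensors"
+//   }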
+class DiffusionPluginsOutputCalculator : public Node { + public: + static constexpr Input> kTensorsIn{"TENSORS"}; + static constexpr Input kStepsIn{"STEPS"}; + static constexpr Input::Optional kIterationIn{"ITERATION"}; + static constexpr Output> kTensorsOut{"TENSORS"}; + MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kStepsIn, kIterationIn, kTensorsOut); + + absl::Status Process(CalculatorContext* cc) override { + if (kTensorsIn(cc).IsEmpty()) { + return absl::OkStatus(); + } + // Consumes the tensor vector to avoid data copy. + absl::StatusOr>> status_or_tensor = + cc->Inputs().Tag("TENSORS").Value().Consume>(); + if (!status_or_tensor.ok()) { + return absl::InternalError("Input tensor vector is not consumable."); + } + if (kIterationIn(cc).IsConnected()) { + CHECK_EQ(kIterationIn(cc).Get(), 0); + kTensorsOut(cc).Send(std::move(*status_or_tensor.value())); + kTensorsOut(cc).SetNextTimestampBound(cc->InputTimestamp() + + kStepsIn(cc).Get()); + } else { + kTensorsOut(cc).Send(std::move(*status_or_tensor.value())); + } + return absl::OkStatus(); + } +}; + +MEDIAPIPE_REGISTER_NODE(DiffusionPluginsOutputCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc b/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc new file mode 100644 index 000000000..77b24a715 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc @@ -0,0 +1,278 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_helpers.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h" +#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h" + +namespace mediapipe { +namespace api2 { +namespace { + +DiffuserPriorityHint ToDiffuserPriorityHint( + StableDiffusionIterateCalculatorOptions::ClPriorityHint priority) { + switch (priority) { + case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_LOW: + return kDiffuserPriorityHintLow; + case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_NORMAL: + return kDiffuserPriorityHintNormal; + case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_HIGH: + return kDiffuserPriorityHintHigh; + } + return kDiffuserPriorityHintNormal; +} + +DiffuserModelType ToDiffuserModelType( + StableDiffusionIterateCalculatorOptions::ModelType model_type) { + switch (model_type) { + case StableDiffusionIterateCalculatorOptions::DEFAULT: + case StableDiffusionIterateCalculatorOptions::SD_1: + return kDiffuserModelTypeSd1; + } + return kDiffuserModelTypeSd1; +} + +} // namespace + +// Runs diffusion models including, but not limited to, Stable Diffusion & gLDM. +// +// Inputs: +// PROMPT - std::string +// The prompt used to generate the image. +// STEPS - int +// The number of steps to run the UNet. +// ITERATION - int +// The iteration of the current run. +// PLUGIN_TENSORS - std::vector @Optional +// The output tensor vector of the diffusion plugins model. +// +// Outputs: +// IMAGE - mediapipe::ImageFrame +// The image generated by the Stable Diffusion model from the input prompt. +// The output image is in RGB format. 
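+//
+// In addition to the basic example below, when the calculator is driven one
+// step at a time with control plugins, the ITERATION and PLUGIN_TENSORS
+// streams are also connected (illustrative sketch; stream names are assumed):
+//   input_stream: "ITERATION:iteration"
+//   input_stream: "PLUGIN_TENSORS:plugin_tensors"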
+// +// Example: +// node { +// calculator: "StableDiffusionIterateCalculator" +// input_stream: "PROMPT:prompt" +// input_stream: "STEPS:steps" +// output_stream: "IMAGE:result" +// options { +// [mediapipe.StableDiffusionIterateCalculatorOptions.ext] { +// base_seed: 0 +// model_type: SD_1 +// } +// } +// } +class StableDiffusionIterateCalculator : public Node { + public: + static constexpr Input kPromptIn{"PROMPT"}; + static constexpr Input kStepsIn{"STEPS"}; + static constexpr Input::Optional kIterationIn{"ITERATION"}; + static constexpr Input::Optional kRandSeedIn{"RAND_SEED"}; + static constexpr SideInput::Optional + kOptionsIn{"OPTIONS"}; + static constexpr Input>::Optional kPlugInTensorsIn{ + "PLUGIN_TENSORS"}; + static constexpr Output kImageOut{"IMAGE"}; + MEDIAPIPE_NODE_CONTRACT(kPromptIn, kStepsIn, kIterationIn, kRandSeedIn, + kPlugInTensorsIn, kOptionsIn, kImageOut); + + ~StableDiffusionIterateCalculator() { + if (context_) DiffuserDelete(); + if (handle_) dlclose(handle_); + } + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + std::vector GetPluginTensors( + CalculatorContext* cc) const { + if (!kPlugInTensorsIn(cc).IsConnected()) return {}; + std::vector diffuser_tensors; + diffuser_tensors.reserve(kPlugInTensorsIn(cc)->size()); + for (const auto& mp_tensor : *kPlugInTensorsIn(cc)) { + DiffuserPluginTensor diffuser_tensor; + diffuser_tensor.shape[0] = mp_tensor.shape().dims[0]; + diffuser_tensor.shape[1] = mp_tensor.shape().dims[1]; + diffuser_tensor.shape[2] = mp_tensor.shape().dims[2]; + diffuser_tensor.shape[3] = mp_tensor.shape().dims[3]; + diffuser_tensor.data = mp_tensor.GetCpuReadView().buffer(); + diffuser_tensors.push_back(diffuser_tensor); + } + return diffuser_tensors; + } + + absl::Status LoadDiffuser() { + handle_ = dlopen("libimagegenerator_gpu.so", RTLD_NOW | RTLD_LOCAL); + RET_CHECK(handle_) << dlerror(); + create_ptr_ = reinterpret_cast( + dlsym(handle_, "DiffuserCreate")); + RET_CHECK(create_ptr_) << dlerror(); + reset_ptr_ = + reinterpret_cast(dlsym(handle_, "DiffuserReset")); + RET_CHECK(reset_ptr_) << dlerror(); + iterate_ptr_ = reinterpret_cast( + dlsym(handle_, "DiffuserIterate")); + RET_CHECK(iterate_ptr_) << dlerror(); + decode_ptr_ = reinterpret_cast( + dlsym(handle_, "DiffuserDecode")); + RET_CHECK(decode_ptr_) << dlerror(); + delete_ptr_ = reinterpret_cast( + dlsym(handle_, "DiffuserDelete")); + RET_CHECK(delete_ptr_) << dlerror(); + return absl::OkStatus(); + } + + DiffuserContext* DiffuserCreate(const DiffuserConfig* a) { + return (*create_ptr_)(a); + } + bool DiffuserReset(const char* a, int b, int c, + const std::vector* d) { + return (*reset_ptr_)(context_, a, b, c, d); + } + bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); } + bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); } + void DiffuserDelete() { (*delete_ptr_)(context_); } + + void* handle_ = nullptr; + DiffuserContext* context_ = nullptr; + DiffuserContext* (*create_ptr_)(const DiffuserConfig*); + int (*reset_ptr_)(DiffuserContext*, const char*, int, int, const void*); + int (*iterate_ptr_)(DiffuserContext*, int, int); + int (*decode_ptr_)(DiffuserContext*, uint8_t*); + void (*delete_ptr_)(DiffuserContext*); + + int show_every_n_iteration_; + bool emit_empty_packet_; +}; + +absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) { + StableDiffusionIterateCalculatorOptions options; + if (kOptionsIn(cc).IsEmpty()) { + options = cc->Options(); + 
} else { + options = kOptionsIn(cc).Get(); + } + show_every_n_iteration_ = options.show_every_n_iteration(); + emit_empty_packet_ = options.emit_empty_packet(); + + MP_RETURN_IF_ERROR(LoadDiffuser()); + + DiffuserConfig config; + config.model_type = ToDiffuserModelType(options.model_type()); + if (options.file_folder().empty()) { + std::strcpy(config.model_dir, "bins/"); // NOLINT + } else { + std::strcpy(config.model_dir, options.file_folder().c_str()); // NOLINT + } + MP_RETURN_IF_ERROR(mediapipe::file::Exists(config.model_dir)) + << config.model_dir; + RET_CHECK(options.lora_file_folder().empty() || + options.lora_weights_layer_mapping().empty()) + << "Can't set both lora_file_folder and lora_weights_layer_mapping."; + std::strcpy(config.lora_dir, options.lora_file_folder().c_str()); // NOLINT + std::map lora_weights_layer_mapping; + for (auto& layer_name_and_weights : options.lora_weights_layer_mapping()) { + lora_weights_layer_mapping[layer_name_and_weights.first] = + (char*)layer_name_and_weights.second; + } + config.lora_weights_layer_mapping = !lora_weights_layer_mapping.empty() + ? &lora_weights_layer_mapping + : nullptr; + config.lora_rank = options.lora_rank(); + config.seed = options.base_seed(); + config.image_width = options.output_image_width(); + config.image_height = options.output_image_height(); + config.run_unet_with_plugins = kPlugInTensorsIn(cc).IsConnected(); + config.env_options = { + .priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()), + .performance_hint = kDiffuserPerformanceHintHigh, + }; + config.plugins_strength = options.plugins_strength(); + RET_CHECK(config.plugins_strength > 0.0f || config.plugins_strength < 1.0f) + << "The value of plugins_strength must be in the range of [0, 1]."; + context_ = DiffuserCreate(&config); + RET_CHECK(context_); + return absl::OkStatus(); +} + +absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) { + const auto& options = + cc->Options().GetExtension(StableDiffusionIterateCalculatorOptions::ext); + const std::string& prompt = *kPromptIn(cc); + const int steps = *kStepsIn(cc); + const int rand_seed = !kRandSeedIn(cc).IsEmpty() ? std::abs(*kRandSeedIn(cc)) + : options.base_seed(); + + if (kIterationIn(cc).IsEmpty()) { + const auto plugin_tensors = GetPluginTensors(cc); + RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors)); + for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i)); + ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(), + options.output_image_height()); + RET_CHECK(DiffuserDecode(image_out.MutablePixelData())); + kImageOut(cc).Send(std::move(image_out)); + } else { + const int iteration = *kIterationIn(cc); + RET_CHECK_LT(iteration, steps); + + // Extract text embedding on first iteration. + if (iteration == 0) { + const auto plugin_tensors = GetPluginTensors(cc); + RET_CHECK( + DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors)); + } + + RET_CHECK(DiffuserIterate(steps, iteration)); + + // Decode the output and send out the image for visualization. 
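+    // Decoding runs only every show_every_n_iteration steps and on the final
+    // step to limit GPU-CPU sync overhead; otherwise an empty packet is
+    // emitted when emit_empty_packet is enabled.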
+ if ((iteration + 1) % show_every_n_iteration_ == 0 || + iteration == steps - 1) { + ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(), + options.output_image_height()); + RET_CHECK(DiffuserDecode(image_out.MutablePixelData())); + kImageOut(cc).Send(std::move(image_out)); + } else if (emit_empty_packet_) { + kImageOut(cc).Send(Packet()); + } + } + return absl::OkStatus(); +} + +MEDIAPIPE_REGISTER_NODE(StableDiffusionIterateCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto b/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto new file mode 100644 index 000000000..ce6dcefd0 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.proto @@ -0,0 +1,84 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +option java_package = "com.google.mediapipe.calculator.proto"; +option java_outer_classname = "StableDiffusionIterateCalculatorOptionsProto"; + +message StableDiffusionIterateCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional StableDiffusionIterateCalculatorOptions ext = 510855836; + } + + // The random seed that is fed into the calculator to control the randomness + // of the generated image. + optional uint32 base_seed = 1 [default = 0]; + + // The target output image size. Must be a multiple of 8 and larger than 384. + optional int32 output_image_width = 2 [default = 512]; + optional int32 output_image_height = 3 [default = 512]; + + // The folder name must end of '/'. + optional string file_folder = 4 [default = "bins/"]; + + // Note: only one of lora_file_folder and lora_weights_layer_mapping should be + // set. + // The LoRA file folder. The folder name must end of '/'. + optional string lora_file_folder = 9 [default = ""]; + + // The LoRA layer name mapping to the weight buffer position in the file. + map lora_weights_layer_mapping = 10; + + // The LoRA rank. + optional int32 lora_rank = 12 [default = 4]; + + // Determine when to run image decoding for how many every iterations. + // Setting this to 1 means we run the image decoding for every iteration for + // displaying the intermediate result, but it will also introduce much higher + // overall latency. + // Setting this to be the targeted number of iterations will only run the + // image decoding at the end, giving the best overall latency. + optional int32 show_every_n_iteration = 5 [default = 1]; + + // If set to be True, the calculator will perform a GPU-CPU sync and emit an + // empty packet. It is used to provide the signal of which iterations it is + // currently at, typically used to create a progress bar. 
Note that this also + // introduce overhead, but not significanly based on our experiments (~1ms). + optional bool emit_empty_packet = 6 [default = false]; + + enum ClPriorityHint { + PRIORITY_HINT_NORMAL = 0; // Default, must be first. + PRIORITY_HINT_LOW = 1; + PRIORITY_HINT_HIGH = 2; + } + + // OpenCL priority hint. Set this to LOW to yield to other GPU contexts. + // This lowers inference speed, but helps keeping the UI responsive. + optional ClPriorityHint cl_priority_hint = 7; + + enum ModelType { + DEFAULT = 0; + SD_1 = 1; // Stable Diffusion v1 models, including SD 1.4 and 1.5. + } + // Stable Diffusion model type. Default to Stable Diffusion v1. + optional ModelType model_type = 8 [default = SD_1]; + // The strength of the diffusion plugins inputs. + optional float plugins_strength = 11 [default = 1.0]; +} diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator.cc b/mediapipe/tasks/cc/vision/image_generator/image_generator.cc new file mode 100644 index 000000000..e4464d84d --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/image_generator.cc @@ -0,0 +1,397 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/image_generator/image_generator.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/image_generator_result.h" +#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_generator { +namespace { + +using ImageGeneratorGraphOptionsProto = ::mediapipe::tasks::vision:: + image_generator::proto::ImageGeneratorGraphOptions; +using ConditionedImageGraphOptionsProto = ::mediapipe::tasks::vision:: + image_generator::proto::ConditionedImageGraphOptions; +using ControlPluginGraphOptionsProto = ::mediapipe::tasks::vision:: + image_generator::proto::ControlPluginGraphOptions; 
+using FaceLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: + face_landmarker::proto::FaceLandmarkerGraphOptions; + +constexpr absl::string_view kImageTag = "IMAGE"; +constexpr absl::string_view kImageOutName = "image_out"; +constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE"; +constexpr absl::string_view kConditionImageName = "condition_image"; +constexpr absl::string_view kSourceConditionImageName = + "source_condition_image"; +constexpr absl::string_view kStepsTag = "STEPS"; +constexpr absl::string_view kStepsName = "steps"; +constexpr absl::string_view kIterationTag = "ITERATION"; +constexpr absl::string_view kIterationName = "iteration"; +constexpr absl::string_view kPromptTag = "PROMPT"; +constexpr absl::string_view kPromptName = "prompt"; +constexpr absl::string_view kRandSeedTag = "RAND_SEED"; +constexpr absl::string_view kRandSeedName = "rand_seed"; +constexpr absl::string_view kSelectTag = "SELECT"; +constexpr absl::string_view kSelectName = "select"; + +constexpr char kImageGeneratorGraphTypeName[] = + "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph"; + +constexpr char kConditionedImageGraphContainerTypeName[] = + "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer"; + +// Creates a MediaPipe graph config that contains a subgraph node of +// "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph". +CalculatorGraphConfig CreateImageGeneratorGraphConfig( + std::unique_ptr options, + bool use_condition_image) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kImageGeneratorGraphTypeName); + subgraph.GetOptions().CopyFrom(*options); + graph.In(kStepsTag).SetName(kStepsName) >> subgraph.In(kStepsTag); + graph.In(kIterationTag).SetName(kIterationName) >> subgraph.In(kIterationTag); + graph.In(kPromptTag).SetName(kPromptName) >> subgraph.In(kPromptTag); + graph.In(kRandSeedTag).SetName(kRandSeedName) >> subgraph.In(kRandSeedTag); + if (use_condition_image) { + graph.In(kConditionImageTag).SetName(kConditionImageName) >> + subgraph.In(kConditionImageTag); + graph.In(kSelectTag).SetName(kSelectName) >> subgraph.In(kSelectTag); + } + subgraph.Out(kImageTag).SetName(kImageOutName) >> + graph[api2::Output::Optional(kImageTag)]; + return graph.GetConfig(); +} + +// Creates a MediaPipe graph config that contains a subgraph node of +// "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer". +CalculatorGraphConfig CreateConditionedImageGraphContainerConfig( + std::unique_ptr options) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kConditionedImageGraphContainerTypeName); + subgraph.GetOptions().CopyFrom(*options); + graph.In(kImageTag).SetName(kSourceConditionImageName) >> + subgraph.In(kImageTag); + graph.In(kSelectTag).SetName(kSelectName) >> subgraph.In(kSelectTag); + subgraph.Out(kConditionImageTag).SetName(kConditionImageName) >> + graph.Out(kConditionImageTag).Cast(); + return graph.GetConfig(); +} + +absl::Status SetFaceConditionOptionsToProto( + FaceConditionOptions& face_condition_options, + ControlPluginGraphOptionsProto& options_proto) { + // Configure face plugin model. + auto plugin_base_options_proto = + std::make_unique( + tasks::core::ConvertBaseOptionsToProto( + &(face_condition_options.base_options))); + options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get()); + + // Configure face landmarker graph. 
+ auto& face_landmarker_options = + face_condition_options.face_landmarker_options; + auto& face_landmarker_options_proto = + *options_proto.mutable_conditioned_image_graph_options() + ->mutable_face_condition_type_options() + ->mutable_face_landmarker_graph_options(); + + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto( + &(face_landmarker_options.base_options))); + face_landmarker_options_proto.mutable_base_options()->Swap( + base_options_proto.get()); + face_landmarker_options_proto.mutable_base_options()->set_use_stream_mode( + false); + + // Configure face detector options. + auto* face_detector_graph_options = + face_landmarker_options_proto.mutable_face_detector_graph_options(); + face_detector_graph_options->set_num_faces(face_landmarker_options.num_faces); + face_detector_graph_options->set_min_detection_confidence( + face_landmarker_options.min_face_detection_confidence); + + // Configure face landmark detector options. + face_landmarker_options_proto.set_min_tracking_confidence( + face_landmarker_options.min_tracking_confidence); + auto* face_landmarks_detector_graph_options = + face_landmarker_options_proto + .mutable_face_landmarks_detector_graph_options(); + face_landmarks_detector_graph_options->set_min_detection_confidence( + face_landmarker_options.min_face_presence_confidence); + return absl::OkStatus(); +} + +absl::Status SetDepthConditionOptionsToProto( + DepthConditionOptions& depth_condition_options, + ControlPluginGraphOptionsProto& options_proto) { + // Configure face plugin model. + auto plugin_base_options_proto = + std::make_unique( + tasks::core::ConvertBaseOptionsToProto( + &(depth_condition_options.base_options))); + options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get()); + + auto& image_segmenter_graph_options = + *options_proto.mutable_conditioned_image_graph_options() + ->mutable_depth_condition_type_options() + ->mutable_image_segmenter_graph_options(); + + auto depth_base_options_proto = + std::make_unique( + tasks::core::ConvertBaseOptionsToProto( + &(depth_condition_options.image_segmenter_options.base_options))); + image_segmenter_graph_options.mutable_base_options()->Swap( + depth_base_options_proto.get()); + image_segmenter_graph_options.mutable_base_options()->set_use_stream_mode( + false); + image_segmenter_graph_options.set_display_names_locale( + depth_condition_options.image_segmenter_options.display_names_locale); + return absl::OkStatus(); +} + +absl::Status SetEdgeConditionOptionsToProto( + EdgeConditionOptions& edge_condition_options, + ControlPluginGraphOptionsProto& options_proto) { + auto plugin_base_options_proto = + std::make_unique( + tasks::core::ConvertBaseOptionsToProto( + &(edge_condition_options.base_options))); + options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get()); + + auto& edge_options_proto = + *options_proto.mutable_conditioned_image_graph_options() + ->mutable_edge_condition_type_options(); + edge_options_proto.set_threshold_1(edge_condition_options.threshold_1); + edge_options_proto.set_threshold_2(edge_condition_options.threshold_2); + edge_options_proto.set_aperture_size(edge_condition_options.aperture_size); + edge_options_proto.set_l2_gradient(edge_condition_options.l2_gradient); + return absl::OkStatus(); +} + +// Helper holder struct of image generator graph options and condition type +// index mapping. 
+struct ImageGeneratorOptionsProtoAndConditionTypeIndex { + std::unique_ptr options_proto; + std::unique_ptr> + condition_type_index; +}; + +// Converts the user-facing ImageGeneratorOptions struct to the internal +// ImageGeneratorOptions proto. +absl::StatusOr +ConvertImageGeneratorGraphOptionsProto( + ImageGeneratorOptions* image_generator_options, + ConditionOptions* condition_options) { + ImageGeneratorOptionsProtoAndConditionTypeIndex + options_proto_and_condition_index; + + // Configure base image generator options. + options_proto_and_condition_index.options_proto = + std::make_unique(); + auto& options_proto = *options_proto_and_condition_index.options_proto; + options_proto.set_text2image_model_directory( + image_generator_options->text2image_model_directory); + if (image_generator_options->lora_weights_file_path.has_value()) { + options_proto.mutable_lora_weights_file()->set_file_name( + *image_generator_options->lora_weights_file_path); + } + + // Configure optional condition type options. + if (condition_options != nullptr) { + options_proto_and_condition_index.condition_type_index = + std::make_unique>(); + auto& condition_type_index = + *options_proto_and_condition_index.condition_type_index; + if (condition_options->face_condition_options.has_value()) { + condition_type_index[ConditionOptions::FACE] = + condition_type_index.size(); + auto& face_plugin_graph_options = + *options_proto.add_control_plugin_graphs_options(); + RET_CHECK_OK(SetFaceConditionOptionsToProto( + *condition_options->face_condition_options, + face_plugin_graph_options)); + } + if (condition_options->depth_condition_options.has_value()) { + condition_type_index[ConditionOptions::DEPTH] = + condition_type_index.size(); + auto& depth_plugin_graph_options = + *options_proto.add_control_plugin_graphs_options(); + RET_CHECK_OK(SetDepthConditionOptionsToProto( + *condition_options->depth_condition_options, + depth_plugin_graph_options)); + } + if (condition_options->edge_condition_options.has_value()) { + condition_type_index[ConditionOptions::EDGE] = + condition_type_index.size(); + auto& edge_plugin_graph_options = + *options_proto.add_control_plugin_graphs_options(); + RET_CHECK_OK(SetEdgeConditionOptionsToProto( + *condition_options->edge_condition_options, + edge_plugin_graph_options)); + } + if (condition_type_index.empty()) { + return absl::InvalidArgumentError( + "At least one condition type must be set."); + } + } + return options_proto_and_condition_index; +} + +} // namespace + +absl::StatusOr> ImageGenerator::Create( + std::unique_ptr image_generator_options, + std::unique_ptr condition_options) { + bool use_condition_image = condition_options != nullptr; + ASSIGN_OR_RETURN(auto options_proto_and_condition_index, + ConvertImageGeneratorGraphOptionsProto( + image_generator_options.get(), condition_options.get())); + std::unique_ptr + options_proto_for_condition_image_graphs_container; + if (use_condition_image) { + options_proto_for_condition_image_graphs_container = + std::make_unique(); + options_proto_for_condition_image_graphs_container->CopyFrom( + *options_proto_and_condition_index.options_proto); + } + ASSIGN_OR_RETURN( + auto image_generator, + (core::VisionTaskApiFactory::Create( + CreateImageGeneratorGraphConfig( + std::move(options_proto_and_condition_index.options_proto), + use_condition_image), + std::make_unique(), + core::RunningMode::IMAGE, + /*result_callback=*/nullptr))); + image_generator->use_condition_image_ = use_condition_image; + if (use_condition_image) { + 
image_generator->condition_type_index_ = + std::move(options_proto_and_condition_index.condition_type_index); + ASSIGN_OR_RETURN( + image_generator->condition_image_graphs_container_task_runner_, + tasks::core::TaskRunner::Create( + CreateConditionedImageGraphContainerConfig( + std::move(options_proto_for_condition_image_graphs_container)), + absl::make_unique())); + } + image_generator->init_timestamp_ = absl::Now(); + return image_generator; +} + +absl::StatusOr ImageGenerator::CreateConditionImage( + Image source_condition_image, + ConditionOptions::ConditionType condition_type) { + if (condition_type_index_->find(condition_type) == + condition_type_index_->end()) { + return absl::InvalidArgumentError( + "The condition type is not created during initialization."); + } + ASSIGN_OR_RETURN( + auto output_packets, + condition_image_graphs_container_task_runner_->Process({ + {std::string(kSourceConditionImageName), + MakePacket(std::move(source_condition_image))}, + {std::string(kSelectName), + MakePacket(condition_type_index_->at(condition_type))}, + })); + return output_packets.at(std::string(kConditionImageName)).Get(); +} + +absl::StatusOr ImageGenerator::Generate( + const std::string& prompt, int iterations, int seed) { + if (use_condition_image_) { + return absl::InvalidArgumentError( + "ImageGenerator is created to use with conditioned image."); + } + return RunIterations(prompt, iterations, seed, std::nullopt); +} + +absl::StatusOr ImageGenerator::Generate( + const std::string& prompt, Image condition_image, + ConditionOptions::ConditionType condition_type, int iterations, int seed) { + if (!use_condition_image_) { + return absl::InvalidArgumentError( + "ImageGenerator is created to use without conditioned image."); + } + ASSIGN_OR_RETURN(auto plugin_model_image, + CreateConditionImage(condition_image, condition_type)); + return RunIterations( + prompt, iterations, seed, + ConditionInputs{plugin_model_image, + condition_type_index_->at(condition_type)}); +} + +absl::StatusOr ImageGenerator::RunIterations( + const std::string& prompt, int steps, int rand_seed, + std::optional condition_inputs) { + tasks::core::PacketMap output_packets; + ImageGeneratorResult result; + auto timestamp = (absl::Now() - init_timestamp_) / absl::Milliseconds(1); + for (int i = 0; i < steps; ++i) { + tasks::core::PacketMap input_packets; + if (i == 0 && condition_inputs.has_value()) { + input_packets[std::string(kConditionImageName)] = + MakePacket(condition_inputs->condition_image) + .At(Timestamp(timestamp)); + input_packets[std::string(kSelectName)] = + MakePacket(condition_inputs->select).At(Timestamp(timestamp)); + } + input_packets[std::string(kStepsName)] = + MakePacket(steps).At(Timestamp(timestamp)); + input_packets[std::string(kIterationName)] = + MakePacket(i).At(Timestamp(timestamp)); + input_packets[std::string(kPromptName)] = + MakePacket(prompt).At(Timestamp(timestamp)); + input_packets[std::string(kRandSeedName)] = + MakePacket(rand_seed).At(Timestamp(timestamp)); + ASSIGN_OR_RETURN(output_packets, ProcessImageData(input_packets)); + timestamp += 1; + } + result.generated_image = + output_packets.at(std::string(kImageOutName)).Get(); + if (condition_inputs.has_value()) { + result.condition_image = condition_inputs->condition_image; + } + return result; +} + +} // namespace image_generator +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator.h b/mediapipe/tasks/cc/vision/image_generator/image_generator.h 
new file mode 100644 index 000000000..52599c02f --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/image_generator.h @@ -0,0 +1,157 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h" +#include "mediapipe/tasks/cc/vision/image_generator/image_generator_result.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_generator { + +// Options for drawing face landmarks image. +struct FaceConditionOptions { + // The base options for plugin model. + tasks::core::BaseOptions base_options; + + // Face landmarker options used to detect face landmarks in the condition + // image. + face_landmarker::FaceLandmarkerOptions face_landmarker_options; +}; + +// Options for detecting edges image. +struct EdgeConditionOptions { + // The base options for plugin model. + tasks::core::BaseOptions base_options; + + // These parameters are used to config Canny edge algorithm of OpenCV. + // See more details: + // https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de + + // First threshold for the hysteresis procedure. + float threshold_1 = 100; + + // Second threshold for the hysteresis procedure. + float threshold_2 = 200; + + // Aperture size for the Sobel operator. Typical range is 3~7. + int aperture_size = 3; + + // A flag, indicating whether a more accurate L2 norm should be used to + // calculate the image gradient magnitude ( L2gradient=true ), or whether + // the default L1 norm is enough ( L2gradient=false ). + bool l2_gradient = false; +}; + +// Options for detecting depth image. +struct DepthConditionOptions { + // The base options for plugin model. + tasks::core::BaseOptions base_options; + + // Image segmenter options used to detect depth in the condition image. + image_segmenter::ImageSegmenterOptions image_segmenter_options; +}; + +struct ConditionOptions { + enum ConditionType { FACE, EDGE, DEPTH }; + std::optional face_condition_options; + std::optional edge_condition_options; + std::optional depth_condition_options; +}; + +// Note: The API is experimental and subject to change. +// The options for configuring a mediapipe image generator task. +struct ImageGeneratorOptions { + // The text to image model directory storing the model weights. 
+ std::string text2image_model_directory; + + // The path to LoRA weights file. + std::optional lora_weights_file_path; +}; + +class ImageGenerator : tasks::vision::core::BaseVisionTaskApi { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates an ImageGenerator from the provided options. + // image_generator_options: options to create the image generator. + // condition_options: optional options if plugin models are used to generate + // an image based on the condition image. + static absl::StatusOr> Create( + std::unique_ptr image_generator_options, + std::unique_ptr condition_options = nullptr); + + // Create the condition image of specified condition type from the source + // condition image. Currently support face landmarks, depth image and edge + // image as the condition image. + absl::StatusOr CreateConditionImage( + Image source_condition_image, + ConditionOptions::ConditionType condition_type); + + // Generates an image for iterations and the given random seed. Only valid + // when the ImageGenerator is created without condition options. + absl::StatusOr Generate(const std::string& prompt, + int iterations, int seed = 0); + + // Generates an image based on the condition image for iterations and the + // given random seed. + // A detailed introduction to the condition image: + // https://ai.googleblog.com/2023/06/on-device-diffusion-plugins-for.html + absl::StatusOr Generate( + const std::string& prompt, Image condition_image, + ConditionOptions::ConditionType condition_type, int iterations, + int seed = 0); + + private: + struct ConditionInputs { + Image condition_image; + int select; + }; + + bool use_condition_image_ = false; + + absl::Time init_timestamp_; + + std::unique_ptr + condition_image_graphs_container_task_runner_; + + std::unique_ptr> + condition_type_index_; + + absl::StatusOr RunIterations( + const std::string& prompt, int steps, int rand_seed, + std::optional condition_inputs); +}; + +} // namespace image_generator +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_ diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc b/mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc new file mode 100644 index 000000000..639a73e34 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/image_generator_graph.cc @@ -0,0 +1,361 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/tool/switch_container.pb.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.pb.h" +#include "mediapipe/util/graph_builder_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_generator { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; + +constexpr int kPluginsOutputSize = 512; +constexpr absl::string_view kTensorsTag = "TENSORS"; +constexpr absl::string_view kImageTag = "IMAGE"; +constexpr absl::string_view kImageCpuTag = "IMAGE_CPU"; +constexpr absl::string_view kStepsTag = "STEPS"; +constexpr absl::string_view kIterationTag = "ITERATION"; +constexpr absl::string_view kPromptTag = "PROMPT"; +constexpr absl::string_view kRandSeedTag = "RAND_SEED"; +constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS"; +constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE"; +constexpr absl::string_view kSelectTag = "SELECT"; +constexpr absl::string_view kMetadataFilename = "metadata"; +constexpr absl::string_view kLoraRankStr = "lora_rank"; + +struct ImageGeneratorInputs { + Source prompt; + Source steps; + Source iteration; + Source rand_seed; + std::optional> condition_image; + std::optional> select_condition_type; +}; + +struct ImageGeneratorOutputs { + Source generated_image; +}; + +} // namespace + +// A container graph containing several ConditionedImageGraph from which to +// choose specified condition type. +// Inputs: +// IMAGE - Image +// The source condition image, used to generate the condition image. +// SELECT - int +// The index of the selected conditioned image graph. +// Outputs: +// CONDITION_IMAGE - Image +// The condition image created from the specified condition type. 
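+//
+// A minimal node sketch (illustrative only; the stream names below are the
+// ones the C++ ImageGenerator wrapper happens to use and are not required by
+// the graph):
+//
+//   node {
+//     calculator:
+//       "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer"
+//     input_stream: "IMAGE:source_condition_image"
+//     input_stream: "SELECT:select"
+//     output_stream: "CONDITION_IMAGE:condition_image"
+//   }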
+class ConditionedImageGraphContainer : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + auto& graph_options = + *sc->MutableOptions(); + auto source_condition_image = graph.In(kImageTag).Cast(); + auto select_condition_type = graph.In(kSelectTag).Cast(); + auto& switch_container = graph.AddNode("SwitchContainer"); + auto& switch_options = + switch_container.GetOptions(); + for (auto& control_plugin_graph_options : + *graph_options.mutable_control_plugin_graphs_options()) { + auto& node = *switch_options.add_contained_node(); + node.set_calculator( + "mediapipe.tasks.vision.image_generator.ConditionedImageGraph"); + node.mutable_node_options()->Add()->PackFrom( + control_plugin_graph_options.conditioned_image_graph_options()); + } + source_condition_image >> switch_container.In(kImageTag); + select_condition_type >> switch_container.In(kSelectTag); + auto condition_image = switch_container.Out(kImageTag).Cast(); + condition_image >> graph.Out(kConditionImageTag); + return graph.GetConfig(); + } +}; + +// clang-format off +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::image_generator::ConditionedImageGraphContainer); // NOLINT +// clang-format on + +// A helper graph to convert condition image to Tensor using the control plugin +// model. +// Inputs: +// CONDITION_IMAGE - Image +// The condition image input to the control plugin model. +// Outputs: +// PLUGIN_TENSORS - std::vector +// The output tensors from the control plugin model. The tensors are used as +// inputs to the image generation model. +class ControlPluginGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + auto& graph_options = + *sc->MutableOptions(); + + auto condition_image = graph.In(kConditionImageTag).Cast(); + + // Convert Image to ImageFrame. + auto& from_image = graph.AddNode("FromImageCalculator"); + condition_image >> from_image.In(kImageTag); + auto image_frame = from_image.Out(kImageCpuTag); + + // Convert ImageFrame to Tensor. + auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator"); + auto& image_to_tensor_options = + image_to_tensor.GetOptions(); + image_to_tensor_options.set_output_tensor_width(kPluginsOutputSize); + image_to_tensor_options.set_output_tensor_height(kPluginsOutputSize); + image_to_tensor_options.mutable_output_tensor_float_range()->set_min(-1); + image_to_tensor_options.mutable_output_tensor_float_range()->set_max(1); + image_to_tensor_options.set_keep_aspect_ratio(true); + image_frame >> image_to_tensor.In(kImageTag); + + // Create the plugin model resource. + ASSIGN_OR_RETURN( + const core::ModelResources* plugin_model_resources, + CreateModelResources( + sc, + std::make_unique( + *graph_options.mutable_base_options()->mutable_model_asset()))); + + // Add control plugin model inference. + auto& plugins_inference = + AddInference(*plugin_model_resources, + graph_options.base_options().acceleration(), graph); + image_to_tensor.Out(kTensorsTag) >> plugins_inference.In(kTensorsTag); + // The plugins model is not runnable on OpenGL. Error message: + // TfLiteGpuDelegate Prepare: Batch size mismatch, expected 1 but got 64 + // Node number 67 (TfLiteGpuDelegate) failed to prepare. 
+    plugins_inference.GetOptions()
+        .mutable_delegate()
+        ->mutable_xnnpack();
+    plugins_inference.Out(kTensorsTag).Cast>() >>
+        graph.Out(kPluginTensorsTag);
+    return graph.GetConfig();
+  }
+};
+
+REGISTER_MEDIAPIPE_GRAPH(
+    ::mediapipe::tasks::vision::image_generator::ControlPluginGraph);
+
+// A "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph" performs image
+// generation from a text prompt, and an optional condition image.
+//
+// Inputs:
+//   PROMPT - std::string
+//     The prompt describing the image to be generated.
+//   STEPS - int
+//     The total steps to generate the image.
+//   ITERATION - int
+//     The current iteration in the generating steps. Must be less than STEPS.
+//   RAND_SEED - int
+//     The random seed input to the image generation model.
+//   CONDITION_IMAGE - Image
+//     The condition image used as guidance for the image generation. Only
+//     valid if control plugin graph options are set in the graph options.
+//   SELECT - int
+//     The index of the selected control plugin graph.
+//
+// Outputs:
+//   IMAGE - Image
+//     The generated image.
+//   STEPS - int @optional
+//     The total steps to generate the image. The same as STEPS input.
+//   ITERATION - int @optional
+//     The current iteration in the generating steps. The same as ITERATION
+//     input.
+class ImageGeneratorGraph : public core::ModelTaskGraph {
+ public:
+  absl::StatusOr GetConfig(
+      SubgraphContext* sc) override {
+    Graph graph;
+    auto* subgraph_options =
+        sc->MutableOptions();
+    std::optional lora_resources;
+    // Create LoRA weights asset bundle resources.
+    if (subgraph_options->has_lora_weights_file()) {
+      auto external_file = std::make_unique();
+      external_file->Swap(subgraph_options->mutable_lora_weights_file());
+      ASSIGN_OR_RETURN(lora_resources, CreateModelAssetBundleResources(
+                                           sc, std::move(external_file)));
+    }
+    std::optional> condition_image;
+    std::optional> select_condition_type;
+    if (!subgraph_options->control_plugin_graphs_options().empty()) {
+      condition_image = graph.In(kConditionImageTag).Cast();
+      select_condition_type = graph.In(kSelectTag).Cast();
+    }
+    ASSIGN_OR_RETURN(
+        auto outputs,
+        BuildImageGeneratorGraph(
+            *sc->MutableOptions(),
+            lora_resources,
+            ImageGeneratorInputs{
+                /*prompt=*/graph.In(kPromptTag).Cast(),
+                /*steps=*/graph.In(kStepsTag).Cast(),
+                /*iteration=*/graph.In(kIterationTag).Cast(),
+                /*rand_seed=*/graph.In(kRandSeedTag).Cast(),
+                /*condition_image*/ condition_image,
+                /*select_condition_type*/ select_condition_type,
+            },
+            graph));
+    outputs.generated_image >> graph.Out(kImageTag).Cast();
+
+    // Optional outputs to provide the current iteration.
+    auto& pass_through = graph.AddNode("PassThroughCalculator");
+    graph.In(kIterationTag) >> pass_through.In(0);
+    graph.In(kStepsTag) >> pass_through.In(1);
+    pass_through.Out(0) >> graph[Output::Optional(kIterationTag)];
+    pass_through.Out(1) >> graph[Output::Optional(kStepsTag)];
+    return graph.GetConfig();
+  }
+
+  absl::StatusOr BuildImageGeneratorGraph(
+      proto::ImageGeneratorGraphOptions& subgraph_options,
+      std::optional lora_resources,
+      ImageGeneratorInputs inputs, Graph& graph) {
+    auto& stable_diff = graph.AddNode("StableDiffusionIterateCalculator");
+    if (inputs.condition_image.has_value()) {
+      // Add switch container for multiple control plugin graphs.
+ auto& switch_container = graph.AddNode("SwitchContainer"); + auto& switch_options = + switch_container.GetOptions(); + for (auto& control_plugin_graph_options : + *subgraph_options.mutable_control_plugin_graphs_options()) { + auto& node = *switch_options.add_contained_node(); + node.set_calculator( + "mediapipe.tasks.vision.image_generator.ControlPluginGraph"); + node.mutable_node_options()->Add()->PackFrom( + control_plugin_graph_options); + } + *inputs.condition_image >> switch_container.In(kConditionImageTag); + *inputs.select_condition_type >> switch_container.In(kSelectTag); + auto plugin_tensors = switch_container.Out(kPluginTensorsTag); + + // Additional diffusion plugins calculator to pass tensors to diffusion + // iterator. + auto& plugins_output = graph.AddNode("DiffusionPluginsOutputCalculator"); + plugin_tensors >> plugins_output.In(kTensorsTag); + inputs.steps >> plugins_output.In(kStepsTag); + inputs.iteration >> plugins_output.In(kIterationTag); + plugins_output.Out(kTensorsTag) >> stable_diff.In(kPluginTensorsTag); + } + + inputs.prompt >> stable_diff.In(kPromptTag); + inputs.steps >> stable_diff.In(kStepsTag); + inputs.iteration >> stable_diff.In(kIterationTag); + inputs.rand_seed >> stable_diff.In(kRandSeedTag); + mediapipe::StableDiffusionIterateCalculatorOptions& options = + stable_diff + .GetOptions(); + options.set_base_seed(0); + options.set_output_image_height(kPluginsOutputSize); + options.set_output_image_width(kPluginsOutputSize); + options.set_file_folder(subgraph_options.text2image_model_directory()); + options.set_show_every_n_iteration(100); + options.set_emit_empty_packet(true); + if (lora_resources.has_value()) { + auto& lora_layer_weights_mapping = + *options.mutable_lora_weights_layer_mapping(); + for (const auto& file_path : (*lora_resources)->ListFiles()) { + auto basename = file::Basename(file_path); + ASSIGN_OR_RETURN(auto file_content, + (*lora_resources)->GetFile(std::string(file_path))); + if (file_path == kMetadataFilename) { + MP_RETURN_IF_ERROR( + ParseLoraMetadataAndConfigOptions(file_content, options)); + } else { + lora_layer_weights_mapping[basename] = + reinterpret_cast(file_content.data()); + } + } + } + + auto& to_image = graph.AddNode("ToImageCalculator"); + stable_diff.Out(kImageTag) >> to_image.In(kImageCpuTag); + + return {{to_image.Out(kImageTag).Cast()}}; + } + + private: + absl::Status ParseLoraMetadataAndConfigOptions( + absl::string_view contents, + mediapipe::StableDiffusionIterateCalculatorOptions& options) { + std::vector lines = + absl::StrSplit(contents, '\n', absl::SkipEmpty()); + for (const auto& line : lines) { + std::vector values = absl::StrSplit(line, ','); + if (values[0] == kLoraRankStr) { + int lora_rank; + if (values.size() != 2 || !absl::SimpleAtoi(values[1], &lora_rank)) { + return absl::InvalidArgumentError( + absl::StrCat("Error parsing LoRA weights metadata. ", line)); + } + options.set_lora_rank(lora_rank); + } + } + return absl::OkStatus(); + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::image_generator::ImageGeneratorGraph); + +} // namespace image_generator +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_generator/image_generator_result.h b/mediapipe/tasks/cc/vision/image_generator/image_generator_result.h new file mode 100644 index 000000000..7b7054d74 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/image_generator_result.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The MediaPipe Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_ + +#include "mediapipe/framework/formats/image.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace image_generator { + +// The result of ImageGenerator task. +struct ImageGeneratorResult { + // The generated image. + Image generated_image; + + // The condition_image used in the plugin model, only available if the + // condition type is set in ImageGeneratorOptions. + std::optional condition_image = std::nullopt; +}; + +} // namespace image_generator +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_ diff --git a/mediapipe/tasks/cc/vision/image_generator/proto/BUILD b/mediapipe/tasks/cc/vision/image_generator/proto/BUILD new file mode 100644 index 000000000..38e1048cf --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/proto/BUILD @@ -0,0 +1,52 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "conditioned_image_graph_options_proto", + srcs = ["conditioned_image_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_proto", + ], +) + +mediapipe_proto_library( + name = "control_plugin_graph_options_proto", + srcs = ["control_plugin_graph_options.proto"], + deps = [ + ":conditioned_image_graph_options_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + +mediapipe_proto_library( + name = "image_generator_graph_options_proto", + srcs = ["image_generator_graph_options.proto"], + deps = [ + ":control_plugin_graph_options_proto", + "//mediapipe/tasks/cc/core/proto:external_file_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto b/mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto new file mode 100644 index 000000000..8d0798d76 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto @@ -0,0 +1,66 @@ + +/* Copyright 2023 The MediaPipe Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +syntax = "proto3"; + +package mediapipe.tasks.vision.image_generator.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto"; +import "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto"; +option java_outer_classname = "ConditionedImageGraphOptionsProto"; + +message ConditionedImageGraphOptions { + // For conditioned image graph based on face landmarks. + message FaceConditionTypeOptions { + // Options for the face landmarker used in the face landmarks type graph. + face_landmarker.proto.FaceLandmarkerGraphOptions + face_landmarker_graph_options = 1; + } + + // For conditioned image graph base on edges detection. + message EdgeConditionTypeOptions { + // These parameters are used to config Canny edge algorithm of OpenCV. + // See more details: + // https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de + + // First threshold for the hysteresis procedure. + float threshold_1 = 1; + + // Second threshold for the hysteresis procedure. + float threshold_2 = 2; + + // Aperture size for the Sobel operator. Typical range is 3~7. 
+ int32 aperture_size = 3; + + // A flag, indicating whether a more accurate L2 norm should be used to + // calculate the image gradient magnitude ( L2gradient=true ), or whether + // the default L1 norm is enough ( L2gradient=false ). + bool l2_gradient = 4; + } + + // For conditioned image graph base on depth map. + message DepthConditionTypeOptions { + // Options for the image segmenter used in the depth condition type graph. + image_segmenter.proto.ImageSegmenterGraphOptions + image_segmenter_graph_options = 1; + } + + // The options for configuring the conditioned image graph. + oneof condition_type_options { + FaceConditionTypeOptions face_condition_type_options = 2; + EdgeConditionTypeOptions edge_condition_type_options = 3; + DepthConditionTypeOptions depth_condition_type_options = 4; + } +} diff --git a/mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto b/mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto new file mode 100644 index 000000000..52d94efb3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto @@ -0,0 +1,34 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package mediapipe.tasks.vision.image_generator.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto"; +option java_outer_classname = "ControlPluginGraphOptionsProto"; + +message ControlPluginGraphOptions { + // The base options for the control plugin model. + core.proto.BaseOptions base_options = 1; + + // The options for the ConditionedImageGraphOptions to generate control plugin + // model input image. + proto.ConditionedImageGraphOptions conditioned_image_graph_options = 2; +} diff --git a/mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto b/mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto new file mode 100644 index 000000000..867080dc3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.proto @@ -0,0 +1,35 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto3"; + +package mediapipe.tasks.vision.image_generator.proto; + +import "mediapipe/tasks/cc/core/proto/external_file.proto"; +import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto"; +option java_outer_classname = "ImageGeneratorGraphOptionsProto"; + +message ImageGeneratorGraphOptions { + // The directory containing the models weight of the text to image model. + string text2image_model_directory = 1; + + // An optional LoRA weights file. If set, the diffusion model will be created + // with LoRA weights. + core.proto.ExternalFile lora_weights_file = 2; + + repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3; +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 0c3500274..0fc4a4974 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -59,6 +59,22 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite", ] +_VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS = [ + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_geometry/proto:mesh_3d_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", +] + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", @@ -249,6 +265,39 @@ EOF native_library = native_library, ) +def mediapipe_tasks_vision_image_generator_aar(name, srcs, native_library): + """Builds medaipipe tasks vision image generator AAR. + + Args: + name: The bazel target name. + srcs: MediaPipe Vision Tasks' source files. + native_library: The native library that contains image generator task's graph and calculators. 
+ """ + + native.genrule( + name = name + "tasks_manifest_generator", + outs = ["AndroidManifest.xml"], + cmd = """ +cat > $(OUTS) < + + + +EOF +""", + ) + + _mediapipe_tasks_aar( + name = name, + srcs = srcs, + manifest = "AndroidManifest.xml", + java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS, + native_library = native_library, + ) + def mediapipe_tasks_text_aar(name, srcs, native_library): """Builds medaipipe tasks text AAR. @@ -344,6 +393,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//third_party:androidx_annotation", "//third_party:autovalue", "@maven//:com_google_guava_guava", + "@com_google_protobuf//:protobuf_javalite", ] + select({ "//conditions:default": [":" + name + "_jni_opencv_cc_lib"], "//mediapipe/framework/port:disable_opencv": [], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index aab542842..1ddcd46c4 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -413,6 +413,9 @@ load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl" mediapipe_tasks_vision_aar( name = "tasks_vision", - srcs = glob(["**/*.java"]), + srcs = glob( + ["**/*.java"], + exclude = ["imagegenerator/**"], + ), native_library = ":libmediapipe_tasks_vision_jni_lib", ) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/AndroidManifest.xml new file mode 100644 index 000000000..5645810d2 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/BUILD new file mode 100644 index 000000000..5a460009a --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/BUILD @@ -0,0 +1,84 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +# The native library of MediaPipe vision image generator tasks. 
+cc_binary( + name = "libmediapipe_tasks_vision_image_generator_jni.so", + linkopts = [ + "-Wl,--no-undefined", + "-Wl,--version-script,$(location //mediapipe/tasks/java:version_script.lds)", + ], + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", + "//mediapipe/tasks/cc/vision/image_generator:image_generator_graph", + "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", + "//mediapipe/tasks/java:version_script.lds", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + ], +) + +cc_library( + name = "libmediapipe_tasks_vision_image_generator_jni_lib", + srcs = [":libmediapipe_tasks_vision_image_generator_jni.so"], + alwayslink = 1, +) + +android_library( + name = "imagegenerator", + srcs = [ + "ImageGenerator.java", + "ImageGeneratorResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "AndroidManifest.xml", + deps = [ + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:core_java", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:facelandmarker", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:imagesegmenter", + "//third_party:any_java_proto", + "//third_party:autovalue", + "//third_party/java/protobuf:protobuf_lite", + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_guava_guava", + ], +) + +load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_image_generator_aar") + +mediapipe_tasks_vision_image_generator_aar( + name = "tasks_vision_image_generator", + srcs = glob(["**/*.java"]), + native_library = ":libmediapipe_tasks_vision_image_generator_jni_lib", +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java new file mode 100644 index 000000000..1de8e4c46 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGenerator.java @@ -0,0 +1,660 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imagegenerator; + +import android.content.Context; +import android.graphics.Bitmap; +import android.util.Log; +import androidx.annotation.Nullable; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskResult; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.ExternalFileProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.facelandmarker.FaceLandmarker.FaceLandmarkerOptions; +import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarkerGraphOptionsProto.FaceLandmarkerGraphOptions; +import com.google.mediapipe.tasks.vision.imagegenerator.proto.ConditionedImageGraphOptionsProto.ConditionedImageGraphOptions; +import com.google.mediapipe.tasks.vision.imagegenerator.proto.ControlPluginGraphOptionsProto; +import com.google.mediapipe.tasks.vision.imagegenerator.proto.ImageGeneratorGraphOptionsProto; +import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions; +import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions; +import com.google.protobuf.Any; +import com.google.protobuf.ExtensionRegistryLite; +import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** Performs image generation from a text prompt. 
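+ *
+ * <p>A minimal usage sketch (illustrative; constructing the {@link ImageGeneratorOptions} instance
+ * is elided and the prompt is an arbitrary example):
+ *
+ * <pre>{@code
+ * ImageGenerator generator = ImageGenerator.createFromOptions(context, generatorOptions);
+ * ImageGeneratorResult result = generator.generate("a mountain lake at sunrise", 20, 0);
+ * }</pre>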
*/ +public final class ImageGenerator extends BaseVisionTaskApi { + + private static final String STEPS_STREAM_NAME = "steps"; + private static final String ITERATION_STREAM_NAME = "iteration"; + private static final String PROMPT_STREAM_NAME = "prompt"; + private static final String RAND_SEED_STREAM_NAME = "rand_seed"; + private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image"; + private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image"; + private static final String SELECT_STREAM_NAME = "select"; + private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0; + private static final int STEPS_OUT_STREAM_INDEX = 1; + private static final int ITERATION_OUT_STREAM_INDEX = 2; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph"; + private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME = + "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer"; + private static final String TAG = "ImageGenerator"; + private TaskRunner conditionImageGraphsContainerTaskRunner; + private Map conditionTypeIndex; + private boolean useConditionImage = false; + + /** + * Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}. + * + * @param context an Android {@link Context}. + * @param generatorOptions an {@link ImageGeneratorOptions} instance. + * @throws MediaPipeException if there is an error during {@link ImageGenerator} creation. + */ + public static ImageGenerator createFromOptions( + Context context, ImageGeneratorOptions generatorOptions) { + return createFromOptions(context, generatorOptions, null); + } + + /** + * Creates an {@link ImageGenerator} instance, from {@link ImageGeneratorOptions} and {@link + * ConditionOptions}, if plugin models are used to generate an image based on the condition image. + * + * @param context an Android {@link Context}. + * @param generatorOptions an {@link ImageGeneratorOptions} instance. + * @param conditionOptions an {@link ConditionOptions} instance. + * @throws MediaPipeException if there is an error during {@link ImageGenerator} creation. 
+ */ + public static ImageGenerator createFromOptions( + Context context, + ImageGeneratorOptions generatorOptions, + @Nullable ConditionOptions conditionOptions) { + List inputStreams = new ArrayList<>(); + inputStreams.addAll( + Arrays.asList( + "STEPS:" + STEPS_STREAM_NAME, + "ITERATION:" + ITERATION_STREAM_NAME, + "PROMPT:" + PROMPT_STREAM_NAME, + "RAND_SEED:" + RAND_SEED_STREAM_NAME)); + final boolean useConditionImage = conditionOptions != null; + if (useConditionImage) { + inputStreams.add("SELECT:" + SELECT_STREAM_NAME); + inputStreams.add("CONDITION_IMAGE:" + CONDITION_IMAGE_STREAM_NAME); + generatorOptions.conditionOptions = Optional.of(conditionOptions); + } + List outputStreams = + Arrays.asList("IMAGE:image_out", "STEPS:steps_out", "ITERATION:iteration_out"); + + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + @Nullable + public ImageGeneratorResult convertToTaskResult(List packets) { + int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX)); + int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX)); + Log.i("ImageGenerator", "Iteration: " + iteration + ", Steps: " + steps); + if (iteration != steps - 1) { + return null; + } + Log.i("ImageGenerator", "processing generated image"); + Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX); + Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet); + BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap); + return ImageGeneratorResult.create( + bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND); + } + + @Override + public Void convertToTaskInput(List packets) { + return null; + } + }); + handler.setHandleTimestampBoundChanges(true); + if (generatorOptions.resultListener().isPresent()) { + ResultListener resultListener = + new ResultListener() { + @Override + public void run(ImageGeneratorResult imageGeneratorResult, Void input) { + generatorOptions.resultListener().get().run(imageGeneratorResult); + } + }; + handler.setResultListener(resultListener); + } + generatorOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(ImageGenerator.class.getSimpleName()) + .setTaskRunningModeName(RunningMode.IMAGE.name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(inputStreams) + .setOutputStreams(outputStreams) + .setTaskOptions(generatorOptions) + .setEnableFlowLimiting(false) + .build(), + handler); + ImageGenerator imageGenerator = new ImageGenerator(runner); + if (useConditionImage) { + imageGenerator.useConditionImage = true; + inputStreams = + Arrays.asList( + "IMAGE:" + SOURCE_CONDITION_IMAGE_STREAM_NAME, "SELECT:" + SELECT_STREAM_NAME); + outputStreams = Arrays.asList("CONDITION_IMAGE:" + CONDITION_IMAGE_STREAM_NAME); + OutputHandler conditionImageHandler = new OutputHandler<>(); + conditionImageHandler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ConditionImageResult convertToTaskResult(List packets) { + Packet packet = packets.get(0); + return new AutoValue_ImageGenerator_ConditionImageResult( + new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(packet)).build(), + packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND); + } + + @Override + public Void convertToTaskInput(List packets) { + return null; + } + }); + 
conditionImageHandler.setHandleTimestampBoundChanges(true); + imageGenerator.conditionImageGraphsContainerTaskRunner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(ImageGenerator.class.getSimpleName()) + .setTaskRunningModeName(RunningMode.IMAGE.name()) + .setTaskGraphName(CONDITION_IMAGE_GRAPHS_CONTAINER_NAME) + .setInputStreams(inputStreams) + .setOutputStreams(outputStreams) + .setTaskOptions(generatorOptions) + .setEnableFlowLimiting(false) + .build(), + conditionImageHandler); + imageGenerator.conditionTypeIndex = new HashMap<>(); + if (conditionOptions.faceConditionOptions().isPresent()) { + imageGenerator.conditionTypeIndex.put( + ConditionOptions.ConditionType.FACE, imageGenerator.conditionTypeIndex.size()); + } + if (conditionOptions.edgeConditionOptions().isPresent()) { + imageGenerator.conditionTypeIndex.put( + ConditionOptions.ConditionType.EDGE, imageGenerator.conditionTypeIndex.size()); + } + if (conditionOptions.depthConditionOptions().isPresent()) { + imageGenerator.conditionTypeIndex.put( + ConditionOptions.ConditionType.DEPTH, imageGenerator.conditionTypeIndex.size()); + } + } + return imageGenerator; + } + + private ImageGenerator(TaskRunner taskRunner) { + super(taskRunner, RunningMode.IMAGE, "", ""); + } + + /** + * Generates an image over the given number of iterations with the given random seed. Only valid + * when the {@link ImageGenerator} is created without condition options. + * + * @param prompt The text prompt describing the image to be generated. + * @param iterations The total iterations to generate the image. + * @param seed The random seed used during image generation. + */ + public ImageGeneratorResult generate(String prompt, int iterations, int seed) { + return runIterations(prompt, iterations, seed, null, 0); + } + + /** + * Generates an image guided by the source condition image, over the given number of iterations + * with the given random seed. Only valid when the {@link ImageGenerator} is created with + * condition options. + * + * @param prompt The text prompt describing the image to be generated. + * @param sourceConditionImage The source image used to create the condition image, which is used + * as guidance for the image generation. + * @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of + * condition image. + * @param iterations The total iterations to generate the image. + * @param seed The random seed used during image generation. + */ + public ImageGeneratorResult generate( + String prompt, + MPImage sourceConditionImage, + ConditionOptions.ConditionType conditionType, + int iterations, + int seed) { + return runIterations( + prompt, + iterations, + seed, + createConditionImage(sourceConditionImage, conditionType), + conditionTypeIndex.get(conditionType)); + } + + /** + * Creates the condition image of the specified condition type from the source image. Currently + * supports face landmarks, depth image, and edge image as the condition image. + * + * @param sourceConditionImage The source image used to create the condition image. + * @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of + * condition image.
*/ + public MPImage createConditionImage( + MPImage sourceConditionImage, ConditionOptions.ConditionType conditionType) { + if (!conditionTypeIndex.containsKey(conditionType)) { + throw new IllegalArgumentException( + "The condition type " + conditionType.name() + " is not created during initialization."); + } + Map<String, Packet> inputPackets = new HashMap<>(); + inputPackets.put( + SOURCE_CONDITION_IMAGE_STREAM_NAME, + conditionImageGraphsContainerTaskRunner + .getPacketCreator() + .createImage(sourceConditionImage)); + inputPackets.put( + SELECT_STREAM_NAME, + conditionImageGraphsContainerTaskRunner + .getPacketCreator() + .createInt32(conditionTypeIndex.get(conditionType))); + ConditionImageResult result = + (ConditionImageResult) conditionImageGraphsContainerTaskRunner.process(inputPackets); + return result.conditionImage(); + } + + private ImageGeneratorResult runIterations( + String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) { + ImageGeneratorResult result = null; + long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND; + for (int i = 0; i < steps; i++) { + Map<String, Packet> inputPackets = new HashMap<>(); + if (i == 0 && useConditionImage) { + inputPackets.put( + CONDITION_IMAGE_STREAM_NAME, runner.getPacketCreator().createImage(conditionImage)); + inputPackets.put(SELECT_STREAM_NAME, runner.getPacketCreator().createInt32(select)); + } + inputPackets.put(PROMPT_STREAM_NAME, runner.getPacketCreator().createString(prompt)); + inputPackets.put(STEPS_STREAM_NAME, runner.getPacketCreator().createInt32(steps)); + inputPackets.put(ITERATION_STREAM_NAME, runner.getPacketCreator().createInt32(i)); + inputPackets.put(RAND_SEED_STREAM_NAME, runner.getPacketCreator().createInt32(seed)); + result = (ImageGeneratorResult) runner.process(inputPackets, timestamp++); + } + if (useConditionImage) { + // Add condition image to the ImageGeneratorResult. + return ImageGeneratorResult.create( + result.generatedImage(), conditionImage, result.timestampMs()); + } + return result; + } + + /** Closes and cleans up the task runners. */ + @Override + public void close() { + runner.close(); + conditionImageGraphsContainerTaskRunner.close(); + } + + /** A container class for the condition image. */ + @AutoValue + protected abstract static class ConditionImageResult implements TaskResult { + + public abstract MPImage conditionImage(); + + @Override + public abstract long timestampMs(); + } + + /** Options for setting up an {@link ImageGenerator}. */ + @AutoValue + public abstract static class ImageGeneratorOptions extends TaskOptions { + + /** Builder for {@link ImageGeneratorOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + + /** Sets the text-to-image model directory storing the model weights. */ + public abstract Builder setText2ImageModelDirectory(String modelDirectory); + + /** Sets the path to the LoRA weights file. */ + public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath); + + /** Sets a {@link PureResultListener} to receive the image generation result. */ + public abstract Builder setResultListener( + PureResultListener<ImageGeneratorResult> resultListener); + + /** Sets an optional {@link ErrorListener}. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract ImageGeneratorOptions autoBuild(); + + /** Validates and builds the {@link ImageGeneratorOptions} instance.
*/ + public final ImageGeneratorOptions build() { + return autoBuild(); + } + } + + abstract String text2ImageModelDirectory(); + + abstract Optional<String> loraWeightsFilePath(); + + abstract Optional<PureResultListener<ImageGeneratorResult>> resultListener(); + + abstract Optional<ErrorListener> errorListener(); + + private Optional<ConditionOptions> conditionOptions; + + public static Builder builder() { + return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder() + .setText2ImageModelDirectory(""); + } + + /** Converts an {@link ImageGeneratorOptions} to an {@link Any} protobuf message. */ + @Override + public Any convertToAnyProto() { + ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.Builder taskOptionsBuilder = + ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.newBuilder(); + if (conditionOptions != null && conditionOptions.isPresent()) { + try { + taskOptionsBuilder.mergeFrom( + conditionOptions.get().convertToAnyProto().getValue(), + ExtensionRegistryLite.getGeneratedRegistry()); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error converting ConditionOptions to proto. " + e.getMessage()); + e.printStackTrace(); + } + } + taskOptionsBuilder.setText2ImageModelDirectory(text2ImageModelDirectory()); + if (loraWeightsFilePath().isPresent()) { + ExternalFileProto.ExternalFile.Builder externalFileBuilder = + ExternalFileProto.ExternalFile.newBuilder(); + externalFileBuilder.setFileName(loraWeightsFilePath().get()); + taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build()); + } + return Any.newBuilder() + .setTypeUrl( + "type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions") + .setValue(taskOptionsBuilder.build().toByteString()) + .build(); + } + } + + /** Options for setting up the condition types and the plugin models. */ + @AutoValue + public abstract static class ConditionOptions extends TaskOptions { + + /** The supported condition type. */ + public enum ConditionType { + FACE, + EDGE, + DEPTH + } + + /** Builder for {@link ConditionOptions}. At least one type of condition options must be set. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setFaceConditionOptions(FaceConditionOptions faceConditionOptions); + + public abstract Builder setDepthConditionOptions(DepthConditionOptions depthConditionOptions); + + public abstract Builder setEdgeConditionOptions(EdgeConditionOptions edgeConditionOptions); + + abstract ConditionOptions autoBuild(); + + /** Validates and builds the {@link ConditionOptions} instance. */ + public final ConditionOptions build() { + ConditionOptions options = autoBuild(); + if (!options.faceConditionOptions().isPresent() + && !options.depthConditionOptions().isPresent() + && !options.edgeConditionOptions().isPresent()) { + throw new IllegalArgumentException( + "At least one of `faceConditionOptions`, `depthConditionOptions` and" + + " `edgeConditionOptions` must be set."); + } + return options; + } + } + + abstract Optional<FaceConditionOptions> faceConditionOptions(); + + abstract Optional<DepthConditionOptions> depthConditionOptions(); + + abstract Optional<EdgeConditionOptions> edgeConditionOptions(); + + public static Builder builder() { + return new AutoValue_ImageGenerator_ConditionOptions.Builder(); + } + + /** + * Converts a {@link ConditionOptions} to an {@link Any} protobuf message.
*/ + @Override + public Any convertToAnyProto() { + ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.Builder taskOptionsBuilder = + ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.newBuilder(); + if (faceConditionOptions().isPresent()) { + taskOptionsBuilder.addControlPluginGraphsOptions( + ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() + .setBaseOptions( + convertBaseOptionsToProto(faceConditionOptions().get().baseOptions())) + .setConditionedImageGraphOptions( + ConditionedImageGraphOptions.newBuilder() + .setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto()) + .build()) + .build()); + } + if (edgeConditionOptions().isPresent()) { + taskOptionsBuilder.addControlPluginGraphsOptions( + ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() + .setBaseOptions( + convertBaseOptionsToProto(edgeConditionOptions().get().baseOptions())) + .setConditionedImageGraphOptions( + ConditionedImageGraphOptions.newBuilder() + .setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto()) + .build()) + .build()); + } + // Note: the depth options are added independently of the edge options. + if (depthConditionOptions().isPresent()) { + taskOptionsBuilder.addControlPluginGraphsOptions( + ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder() + .setBaseOptions( + convertBaseOptionsToProto(depthConditionOptions().get().baseOptions())) + .setConditionedImageGraphOptions( + ConditionedImageGraphOptions.newBuilder() + .setDepthConditionTypeOptions( + depthConditionOptions().get().convertToProto()) + .build()) + .build()); + } + return Any.newBuilder() + .setTypeUrl( + "type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions") + .setValue(taskOptionsBuilder.build().toByteString()) + .build(); + } + + /** Options for drawing the face landmarks image. */ + @AutoValue + public abstract static class FaceConditionOptions extends TaskOptions { + + /** Builder for {@link FaceConditionOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the plugin model. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** {@link FaceLandmarkerOptions} used to detect face landmarks in the source image. */ + public abstract Builder setFaceLandmarkerOptions( + FaceLandmarkerOptions faceLandmarkerOptions); + + abstract FaceConditionOptions autoBuild(); + + /** Validates and builds the {@link FaceConditionOptions} instance. */ + public final FaceConditionOptions build() { + return autoBuild(); + } + } + + abstract BaseOptions baseOptions(); + + abstract FaceLandmarkerOptions faceLandmarkerOptions(); + + public static Builder builder() { + return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder(); + } + + ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() { + return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder() + .setFaceLandmarkerGraphOptions( + FaceLandmarkerGraphOptions.newBuilder() + .mergeFrom( + faceLandmarkerOptions() + .convertToCalculatorOptionsProto() + .getExtension(FaceLandmarkerGraphOptions.ext)) + .build()) + .build(); + } + } + + /** Options for detecting the depth image. */ + @AutoValue + public abstract static class DepthConditionOptions extends TaskOptions { + + /** Builder for {@link DepthConditionOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + + /** Sets the base options for the plugin model.
*/ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** {@link ImageSegmenterOptions} used to detect the depth image from the source image. */ + public abstract Builder setImageSegmenterOptions( + ImageSegmenterOptions imageSegmenterOptions); + + abstract DepthConditionOptions autoBuild(); + + /** Validates and builds the {@link DepthConditionOptions} instance. */ + public final DepthConditionOptions build() { + DepthConditionOptions options = autoBuild(); + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract ImageSegmenterOptions imageSegmenterOptions(); + + public static Builder builder() { + return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder(); + } + + ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() { + return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder() + .setImageSegmenterGraphOptions( + imageSegmenterOptions() + .convertToCalculatorOptionsProto() + .getExtension(ImageSegmenterGraphOptions.ext)) + .build(); + } + } + + /** Options for detecting the edge image. */ + @AutoValue + public abstract static class EdgeConditionOptions { + + /** + * Builder for {@link EdgeConditionOptions}. + * + * <p>These parameters are used to configure the Canny edge algorithm of OpenCV. + * + * <p>
See more details: + * https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de + */ + @AutoValue.Builder + public abstract static class Builder { + + /** Sets the base options for the plugin model. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** First threshold for the hysteresis procedure. */ + public abstract Builder setThreshold1(Float threshold1); + + /** Second threshold for the hysteresis procedure. */ + public abstract Builder setThreshold2(Float threshold2); + + /** Aperture size for the Sobel operator. Typical range is 3~7. */ + public abstract Builder setApertureSize(Integer apertureSize); + + /** + * Flag indicating whether a more accurate L2 norm should be used to calculate the image + * gradient magnitude (L2gradient=true), or whether the default L1 norm is enough + * (L2gradient=false). + */ + public abstract Builder setL2Gradient(Boolean l2Gradient); + + abstract EdgeConditionOptions autoBuild(); + + /** Validates and builds the {@link EdgeConditionOptions} instance. */ + public final EdgeConditionOptions build() { + return autoBuild(); + } + } + + abstract BaseOptions baseOptions(); + + abstract Float threshold1(); + + abstract Float threshold2(); + + abstract Integer apertureSize(); + + abstract Boolean l2Gradient(); + + public static Builder builder() { + return new AutoValue_ImageGenerator_ConditionOptions_EdgeConditionOptions.Builder() + .setThreshold1(100f) + .setThreshold2(200f) + .setApertureSize(3) + .setL2Gradient(false); + } + + ConditionedImageGraphOptions.EdgeConditionTypeOptions convertToProto() { + return ConditionedImageGraphOptions.EdgeConditionTypeOptions.newBuilder() + .setThreshold1(threshold1()) + .setThreshold2(threshold2()) + .setApertureSize(apertureSize()) + .setL2Gradient(l2Gradient()) + .build(); + } + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGeneratorResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGeneratorResult.java new file mode 100644 index 000000000..6bb3ab60e --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagegenerator/ImageGeneratorResult.java @@ -0,0 +1,44 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imagegenerator; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.TaskResult; +import java.util.Optional; + +/** Represents the image generation result generated by {@link ImageGenerator}. */ +@AutoValue +public abstract class ImageGeneratorResult implements TaskResult { + + /** Creates an {@link ImageGeneratorResult} instance from the generated image and the condition image.
*/ + public static ImageGeneratorResult create( + MPImage generatedImage, MPImage conditionImage, long timestampMs) { + return new AutoValue_ImageGeneratorResult( + generatedImage, Optional.of(conditionImage), timestampMs); + } + + /** Creates an {@link ImageGeneratorResult} instance from the generated image. */ + public static ImageGeneratorResult create(MPImage generatedImage, long timestampMs) { + return new AutoValue_ImageGeneratorResult(generatedImage, Optional.empty(), timestampMs); + } + + public abstract MPImage generatedImage(); + + public abstract Optional<MPImage> conditionImage(); + + @Override + public abstract long timestampMs(); +}
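Example: conditioned generation with an edge plugin model. The following is a minimal sketch of how the API added in this patch could be driven from application code; it is illustrative only. The plugin model path, prompt, iteration count, and seed are placeholder values, and `context`, `sourceImage`, and `generatorOptions` (an ImageGeneratorOptions built as in the class javadoc above) are assumed to exist in the caller.

    ImageGenerator.ConditionOptions conditionOptions =
        ImageGenerator.ConditionOptions.builder()
            .setEdgeConditionOptions(
                ImageGenerator.ConditionOptions.EdgeConditionOptions.builder()
                    .setBaseOptions(
                        BaseOptions.builder()
                            .setModelAssetPath("/data/local/tmp/edge_plugin.tflite") // placeholder path
                            .build())
                    .setThreshold1(100f) // Canny hysteresis thresholds; defaults shown
                    .setThreshold2(200f)
                    .build())
            .build();
    ImageGenerator generator =
        ImageGenerator.createFromOptions(context, generatorOptions, conditionOptions);
    // Optionally inspect the derived condition image before running the diffusion iterations.
    MPImage conditionImage =
        generator.createConditionImage(
            sourceImage, ImageGenerator.ConditionOptions.ConditionType.EDGE);
    ImageGeneratorResult result =
        generator.generate(
            "a watercolor painting of the same scene", // placeholder prompt
            sourceImage,
            ImageGenerator.ConditionOptions.ConditionType.EDGE,
            /*iterations=*/ 20,
            /*seed=*/ 0);
    generator.close();

Because condition options are supplied, the returned ImageGeneratorResult also carries the condition image via conditionImage().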