Internal update

PiperOrigin-RevId: 561148365
This commit is contained in:
MediaPipe Team 2023-08-29 15:02:23 -07:00 committed by Copybara-Service
parent 01fbbd9f67
commit e18e749e3e
22 changed files with 3321 additions and 1 deletion

View File

@ -0,0 +1,136 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])
# Graph that converts an input image into a guidance ("condition") image for
# the diffusion plugin model: a face-landmarks rendering, a depth map, or a
# Canny edge map, depending on the configured condition type options.
cc_library(
    name = "conditioned_image_graph",
    srcs = ["conditioned_image_graph.cc"],
    deps = [
        "//mediapipe/calculators/core:get_vector_item_calculator",
        "//mediapipe/calculators/core:get_vector_item_calculator_cc_proto",
        "//mediapipe/calculators/util:annotation_overlay_calculator",
        "//mediapipe/calculators/util:flat_color_image_calculator",
        "//mediapipe/calculators/util:flat_color_image_calculator_cc_proto",
        "//mediapipe/calculators/util:landmarks_to_render_data_calculator",
        "//mediapipe/calculators/util:landmarks_to_render_data_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:image_format_cc_proto",
        "//mediapipe/framework/formats:image_frame_opencv",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/port:opencv_core",
        "//mediapipe/framework/port:opencv_imgcodecs",
        "//mediapipe/framework/port:opencv_imgproc",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
        "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarks_connections",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
        "//mediapipe/util:color_cc_proto",
        "//mediapipe/util:image_frame_util",
        "//mediapipe/util:render_data_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
    # Graph/calculator registration happens via static initializers, so the
    # library must always be linked in even without direct symbol references.
    alwayslink = 1,
)
# Top-level image generation graph: runs the stable diffusion iterate
# calculator, optionally preceded by conditioned-image subgraphs and control
# plugin models.
cc_library(
    name = "image_generator_graph",
    srcs = ["image_generator_graph.cc"],
    deps = [
        ":conditioned_image_graph",
        "//mediapipe/calculators/core:pass_through_calculator",
        "//mediapipe/calculators/image:image_transformation_calculator",
        "//mediapipe/calculators/image:image_transformation_calculator_cc_proto",
        "//mediapipe/calculators/tensor:image_to_tensor_calculator",
        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/calculators/tensor:inference_calculator_cc_proto",
        "//mediapipe/calculators/util:from_image_calculator",
        "//mediapipe/calculators/util:to_image_calculator",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:stream_handler_cc_proto",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:switch_container",
        "//mediapipe/framework/tool:switch_container_cc_proto",
        "//mediapipe/tasks/cc/core:model_asset_bundle_resources",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/diffuser:diffusion_plugins_output_calculator",
        "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator",
        "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_cc_proto",
        "//mediapipe/util:graph_builder_utils",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
    # Registered via static initializers; must not be dead-stripped.
    alwayslink = 1,
)
# Header-only result type returned by the ImageGenerator task API.
cc_library(
    name = "image_generator_result",
    hdrs = ["image_generator_result.h"],
    deps = ["//mediapipe/framework/formats:image"],
)
# Public C++ task API wrapping the image generator graph.
cc_library(
    name = "image_generator",
    srcs = ["image_generator.cc"],
    hdrs = ["image_generator.h"],
    deps = [
        ":image_generator_graph",
        ":image_generator_result",
        "//mediapipe/framework:packet",
        "//mediapipe/framework:timestamp",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:task_runner",
        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_segmenter",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
    ],
)

View File

@ -0,0 +1,458 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <optional>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/core/get_vector_item_calculator.h"
#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/image_frame_util.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_generator {
namespace internal {
// Helper postprocessing calculator for depth condition type: scales the raw
// single-channel depth inference result to the 0-255 uint8 range and converts
// it to an SRGB image suitable as a guidance image.
class DepthImagePostprocessingCalculator : public api2::Node {
 public:
  static constexpr api2::Input<Image> kImageIn{"IMAGE"};
  static constexpr api2::Output<Image> kImageOut{"IMAGE"};

  MEDIAPIPE_NODE_CONTRACT(kImageIn, kImageOut);

  absl::Status Process(CalculatorContext* cc) final {
    if (kImageIn(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    Image raw_depth_image = kImageIn(cc).Get();
    cv::Mat raw_depth_mat = mediapipe::formats::MatView(
        raw_depth_image.GetImageFrameSharedPtr().get());
    cv::Mat depth_mat;
    // NORM_MINMAX stretches the value range to [0, 255] (OpenCV takes the
    // min/max of alpha and beta, so the (255, 0) argument order is valid).
    cv::normalize(raw_depth_mat, depth_mat, 255, 0, cv::NORM_MINMAX);
    // convertTo only changes the depth to 8U (channel count is taken from the
    // source); the GRAY2RGB conversion below produces the 3-channel image.
    depth_mat.convertTo(depth_mat, CV_8UC3, 1, 0);
    cv::cvtColor(depth_mat, depth_mat, cv::COLOR_GRAY2RGB);
    // Wrap the cv::Mat buffer in an ImageFrame without copying: the deleter's
    // by-value capture keeps the buffer alive until the ImageFrame (and with
    // it the deleter) is destroyed, which releases the Mat's refcount.
    // NOTE: the previous explicit `depth_mat.~Mat()` call in the deleter body
    // destroyed the captured Mat twice (once explicitly, once when the lambda
    // was destroyed), which is undefined behavior.
    ImageFrameSharedPtr depth_image_frame_ptr = std::make_shared<ImageFrame>(
        mediapipe::ImageFormat::SRGB, depth_mat.cols, depth_mat.rows,
        depth_mat.step, depth_mat.data, [depth_mat](uint8_t[]) {});
    Image depth_image(depth_image_frame_ptr);
    kImageOut(cc).Send(depth_image);
    return absl::OkStatus();
  }
};

// NOLINTBEGIN: Node registration doesn't work when part of calculator name is
// moved to next line.
// clang-format off
MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::vision::image_generator::internal::DepthImagePostprocessingCalculator);
// clang-format on
// NOLINTEND
// Calculator to detect edges in the image with OpenCV Canny edge detection,
// producing a 3-channel SRGB edge map for the diffusion plugin model.
class CannyEdgeCalculator : public api2::Node {
 public:
  static constexpr api2::Input<Image> kImageIn{"IMAGE"};
  static constexpr api2::Output<Image> kImageOut{"IMAGE"};

  MEDIAPIPE_NODE_CONTRACT(kImageIn, kImageOut);

  absl::Status Process(CalculatorContext* cc) final {
    if (kImageIn(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    Image input_image = kImageIn(cc).Get();
    cv::Mat input_image_mat =
        mediapipe::formats::MatView(input_image.GetImageFrameSharedPtr().get());
    // Thresholds, aperture size and L2-gradient flag come from graph options.
    const auto& options = cc->Options<
        proto::ConditionedImageGraphOptions::EdgeConditionTypeOptions>();
    cv::Mat luminance;  // Fixed local-variable typo: was "lumincance".
    cv::cvtColor(input_image_mat, luminance, cv::COLOR_RGB2GRAY);
    cv::Mat edges_mat;
    cv::Canny(luminance, edges_mat, options.threshold_1(),
              options.threshold_2(), options.aperture_size(),
              options.l2_gradient());
    cv::normalize(edges_mat, edges_mat, 255, 0, cv::NORM_MINMAX);
    edges_mat.convertTo(edges_mat, CV_8UC3, 1, 0);
    cv::cvtColor(edges_mat, edges_mat, cv::COLOR_GRAY2RGB);
    // Wrap the cv::Mat buffer in an ImageFrame without copying: the deleter's
    // by-value capture keeps the buffer alive until the ImageFrame releases
    // it. NOTE: the previous explicit `edges_mat.~Mat()` call in the deleter
    // body destroyed the captured Mat twice (undefined behavior).
    ImageFrameSharedPtr edges_image_frame_ptr = std::make_shared<ImageFrame>(
        mediapipe::ImageFormat::SRGB, edges_mat.cols, edges_mat.rows,
        edges_mat.step, edges_mat.data, [edges_mat](uint8_t[]) {});
    Image edges_image(edges_image_frame_ptr);
    kImageOut(cc).Send(edges_image);
    return absl::OkStatus();
  }
};

// NOLINTBEGIN: Node registration doesn't work when part of calculator name is
// moved to next line.
// clang-format off
MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::vision::image_generator::internal::CannyEdgeCalculator);
// clang-format on
// NOLINTEND
} // namespace internal
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
// Stream/tag names used when wiring the calculators below.
constexpr absl::string_view kImageTag = "IMAGE";
// AnnotationOverlayCalculator consumes/produces the canvas on "UIMAGE".
constexpr absl::string_view kUImageTag = "UIMAGE";
constexpr absl::string_view kNormLandmarksTag = "NORM_LANDMARKS";
constexpr absl::string_view kVectorTag = "VECTOR";
constexpr absl::string_view kItemTag = "ITEM";
constexpr absl::string_view kRenderDataTag = "RENDER_DATA";
// First confidence-mask output of ImageSegmenterGraph; used as the raw depth
// map in the depth condition branch.
constexpr absl::string_view kConfidenceMaskTag = "CONFIDENCE_MASK:0";

// Colors used for drawing face landmark connections; mapped to RGB values by
// GetColor() below.
enum ColorType {
  WHITE = 0,
  GREEN = 1,
  RED = 2,
  BLACK = 3,
  BLUE = 4,
};
// Translates a ColorType enum value into a mediapipe::Color proto. All three
// channels are set explicitly (including zeros) for every color.
mediapipe::Color GetColor(ColorType color_type) {
  int red = 0;
  int green = 0;
  int blue = 0;
  switch (color_type) {
    case WHITE:
      red = green = blue = 255;
      break;
    case GREEN:
      green = 255;
      break;
    case RED:
      red = 255;
      break;
    case BLACK:
      break;
    case BLUE:
      blue = 255;
      break;
  }
  mediapipe::Color color;
  color.set_r(red);
  color.set_g(green);
  color.set_b(blue);
  return color;
}
// Builds LandmarksToRenderDataCalculatorOptions that draw only the given
// landmark `connections` (no landmark points, no depth) in the given color.
// Connections are encoded as a flat list of endpoint-index pairs.
mediapipe::LandmarksToRenderDataCalculatorOptions
GetFaceLandmarksRenderDataOptions(
    absl::Span<const std::array<int, 2>> connections, ColorType color_type) {
  mediapipe::LandmarksToRenderDataCalculatorOptions options;
  options.set_thickness(1);
  options.set_visualize_landmark_depth(false);
  options.set_render_landmarks(false);
  *options.mutable_connection_color() = GetColor(color_type);
  for (const std::array<int, 2>& endpoints : connections) {
    options.add_landmark_connections(endpoints[0]);
    options.add_landmark_connections(endpoints[1]);
  }
  return options;
}
// Adds a LandmarksToRenderDataCalculator to `graph`, configured with
// `landmarks_to_render_data_options`, feeds it `face_landmarks`, and returns
// the resulting render-data stream.
Source<mediapipe::RenderData> GetFaceLandmarksRenderData(
    Source<mediapipe::NormalizedLandmarkList> face_landmarks,
    const mediapipe::LandmarksToRenderDataCalculatorOptions&
        landmarks_to_render_data_options,
    Graph& graph) {
  auto& node = graph.AddNode("LandmarksToRenderDataCalculator");
  node.GetOptions<mediapipe::LandmarksToRenderDataCalculatorOptions>()
      .CopyFrom(landmarks_to_render_data_options);
  face_landmarks >> node.In(kNormLandmarksTag);
  return node.Out(kRenderDataTag).Cast<mediapipe::RenderData>();
}
// Add FaceLandmarkerGraph to detect the face landmarks in the given face image,
// and generate a face mesh guidance image for the diffusion plugin model.
// The guidance image is a black canvas with the face oval and lips drawn in
// white, left-side features in green and right-side features in blue.
// Fails if the options are not configured for exactly one face.
absl::StatusOr<Source<Image>> GetFaceLandmarksImage(
    Source<Image> face_image,
    const proto::ConditionedImageGraphOptions::FaceConditionTypeOptions&
        face_condition_type_options,
    Graph& graph) {
  if (face_condition_type_options.face_landmarker_graph_options()
          .face_detector_graph_options()
          .num_faces() != 1) {
    return absl::InvalidArgumentError(
        "Only supports face landmarks of a single face as the guidance image.");
  }
  // Detect face landmarks.
  auto& face_landmarker_graph = graph.AddNode(
      "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph");
  face_landmarker_graph
      .GetOptions<face_landmarker::proto::FaceLandmarkerGraphOptions>()
      .CopyFrom(face_condition_type_options.face_landmarker_graph_options());
  face_image >> face_landmarker_graph.In(kImageTag);
  auto face_landmarks_lists =
      face_landmarker_graph.Out(kNormLandmarksTag)
          .Cast<std::vector<mediapipe::NormalizedLandmarkList>>();
  // Get the single face landmarks (item 0 of the per-face landmark lists;
  // num_faces was validated to be 1 above).
  auto& get_vector_item =
      graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator");
  get_vector_item.GetOptions<mediapipe::GetVectorItemCalculatorOptions>()
      .set_item_index(0);
  face_landmarks_lists >> get_vector_item.In(kVectorTag);
  auto single_face_landmarks =
      get_vector_item.Out(kItemTag).Cast<mediapipe::NormalizedLandmarkList>();
  // Convert face landmarks to render data, one stream per facial feature.
  // NOTE(review): the color scheme (white oval/lips, green left, blue right)
  // presumably matches how the face plugin model was trained — confirm before
  // changing any color below.
  auto face_oval = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksFaceOval
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksFaceOval
                  .size()),
          ColorType::WHITE),
      graph);
  auto lips = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLips
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLips
                  .size()),
          ColorType::WHITE),
      graph);
  auto left_eye = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftEye
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftEye
                  .size()),
          ColorType::GREEN),
      graph);
  auto left_eye_brow = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::
                  kFaceLandmarksLeftEyeBrow.data(),
              face_landmarker::FaceLandmarksConnections::
                  kFaceLandmarksLeftEyeBrow.size()),
          ColorType::GREEN),
      graph);
  auto left_iris = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftIris
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftIris
                  .size()),
          ColorType::GREEN),
      graph);
  auto right_eye = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightEye
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightEye
                  .size()),
          ColorType::BLUE),
      graph);
  auto right_eye_brow = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::
                  kFaceLandmarksRightEyeBrow.data(),
              face_landmarker::FaceLandmarksConnections::
                  kFaceLandmarksRightEyeBrow.size()),
          ColorType::BLUE),
      graph);
  auto right_iris = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightIris
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightIris
                  .size()),
          ColorType::BLUE),
      graph);
  // Create a black canvas image with same size as face image (only the red
  // channel is set explicitly; unset green/blue default to 0).
  auto& flat_color = graph.AddNode("FlatColorImageCalculator");
  flat_color.GetOptions<mediapipe::FlatColorImageCalculatorOptions>()
      .mutable_color()
      ->set_r(0);
  face_image >> flat_color.In(kImageTag);
  auto blank_canvas = flat_color.Out(kImageTag);
  // Draw render data on the canvas image. The numeric input indices determine
  // draw order: later inputs are drawn on top of earlier ones.
  auto& annotation_overlay = graph.AddNode("AnnotationOverlayCalculator");
  blank_canvas >> annotation_overlay.In(kUImageTag);
  face_oval >> annotation_overlay.In(0);
  lips >> annotation_overlay.In(1);
  left_eye >> annotation_overlay.In(2);
  left_eye_brow >> annotation_overlay.In(3);
  left_iris >> annotation_overlay.In(4);
  right_eye >> annotation_overlay.In(5);
  right_eye_brow >> annotation_overlay.In(6);
  right_iris >> annotation_overlay.In(7);
  return annotation_overlay.Out(kUImageTag).Cast<Image>();
}
// Builds the depth condition branch: runs ImageSegmenterGraph on `image` to
// produce a raw depth map (its first confidence mask), then rescales it to an
// SRGB guidance image via the internal DepthImagePostprocessingCalculator.
absl::StatusOr<Source<Image>> GetDepthImage(
    Source<Image> image,
    const image_generator::proto::ConditionedImageGraphOptions::
        DepthConditionTypeOptions& depth_condition_type_options,
    Graph& graph) {
  auto& segmenter = graph.AddNode(
      "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph");
  segmenter.GetOptions<image_segmenter::proto::ImageSegmenterGraphOptions>()
      .CopyFrom(depth_condition_type_options.image_segmenter_graph_options());
  image >> segmenter.In(kImageTag);
  auto& postprocessing = graph.AddNode(
      "mediapipe.tasks.vision.image_generator.internal."
      "DepthImagePostprocessingCalculator");
  segmenter.Out(kConfidenceMaskTag) >> postprocessing.In(kImageTag);
  return postprocessing.Out(kImageTag).Cast<Image>();
}
// Builds the edge condition branch: runs the internal CannyEdgeCalculator on
// `image`, configured from `edge_condition_type_options`.
absl::StatusOr<Source<Image>> GetEdgeImage(
    Source<Image> image,
    const image_generator::proto::ConditionedImageGraphOptions::
        EdgeConditionTypeOptions& edge_condition_type_options,
    Graph& graph) {
  auto& canny = graph.AddNode(
      "mediapipe.tasks.vision.image_generator.internal."
      "CannyEdgeCalculator");
  canny
      .GetOptions<
          proto::ConditionedImageGraphOptions::EdgeConditionTypeOptions>()
      .CopyFrom(edge_condition_type_options);
  image >> canny.In(kImageTag);
  return canny.Out(kImageTag).Cast<Image>();
}
} // namespace
// A mediapipe.tasks.vision.image_generator.ConditionedImageGraph converts the
// input image to an image of condition type. The output image can be used as
// input for the diffusion model with control plugin.
//
// Inputs:
//   IMAGE - Image
//     Conditioned image to generate the image for diffusion plugin model.
//
// Outputs:
//   IMAGE - Image
//     The guidance image used as input for the diffusion plugin model.
class ConditionedImageGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    auto& graph_options =
        *sc->MutableOptions<proto::ConditionedImageGraphOptions>();
    Source<Image> conditioned_image = graph.In(kImageTag).Cast<Image>();
    // Configure the guidance graph and get the guidance image if guidance graph
    // options is set. Exactly one branch of the options oneof selects which
    // condition subgraph is wired in at graph-construction time.
    switch (graph_options.condition_type_options_case()) {
      case proto::ConditionedImageGraphOptions::CONDITION_TYPE_OPTIONS_NOT_SET:
        // No condition type configured: fail fast at construction time.
        return absl::InvalidArgumentError(
            "Conditioned type options is not set.");
        break;
      case proto::ConditionedImageGraphOptions::kFaceConditionTypeOptions: {
        ASSIGN_OR_RETURN(
            auto face_landmarks_image,
            GetFaceLandmarksImage(conditioned_image,
                                  graph_options.face_condition_type_options(),
                                  graph));
        face_landmarks_image >> graph.Out(kImageTag);
      } break;
      case proto::ConditionedImageGraphOptions::kDepthConditionTypeOptions: {
        ASSIGN_OR_RETURN(
            auto depth_image,
            GetDepthImage(conditioned_image,
                          graph_options.depth_condition_type_options(), graph));
        depth_image >> graph.Out(kImageTag);
      } break;
      case proto::ConditionedImageGraphOptions::kEdgeConditionTypeOptions: {
        ASSIGN_OR_RETURN(
            auto edges_image,
            GetEdgeImage(conditioned_image,
                         graph_options.edge_condition_type_options(), graph));
        edges_image >> graph.Out(kImageTag);
      } break;
    }
    return graph.GetConfig();
  }
};

REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::vision::image_generator::ConditionedImageGraph);
} // namespace image_generator
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,147 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_generator {
namespace {
using ::mediapipe::Image;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kFaceLandmarkerModel[] = "face_landmarker_v2.task";
constexpr char kDepthModel[] =
"mobilenetsweep_dptrigmqn384_unit_384_384_fp16quant_fp32input_opt.tflite";
constexpr char kPortraitImage[] = "portrait.jpg";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageInStream[] = "image_in";
constexpr char kImageOutStream[] = "image_out";
// Helper function to create a ConditionedImageGraphTaskRunner TaskRunner.
// Wires a single ConditionedImageGraph node between the graph's image input
// and output streams and wraps the config in a TaskRunner using the builtin
// op resolver.
absl::StatusOr<std::unique_ptr<TaskRunner>>
CreateConditionedImageGraphTaskRunner(
    std::unique_ptr<proto::ConditionedImageGraphOptions> options) {
  Graph graph;
  auto& conditioned_image_graph = graph.AddNode(
      "mediapipe.tasks.vision.image_generator.ConditionedImageGraph");
  // Swap (rather than copy) the caller-provided options into the node.
  conditioned_image_graph.GetOptions<proto::ConditionedImageGraphOptions>()
      .Swap(options.get());
  graph.In(kImageTag).Cast<Image>().SetName(kImageInStream) >>
      conditioned_image_graph.In(kImageTag);
  conditioned_image_graph.Out(kImageTag).SetName(kImageOutStream) >>
      graph.Out(kImageTag).Cast<Image>();
  // std::make_unique for consistency with the rest of this file (the tests
  // below use std::make_unique; <memory> is already included).
  return core::TaskRunner::Create(
      graph.GetConfig(),
      std::make_unique<tasks::core::MediaPipeBuiltinOpResolver>());
}
// Runs the face condition branch on a portrait image and saves the rendered
// face-landmarks guidance image as test output. Only checks that the pipeline
// runs successfully; no pixel-level golden comparison.
TEST(ConditionedImageGraphTest, SucceedsFaceLandmarkerConditionType) {
  auto options = std::make_unique<proto::ConditionedImageGraphOptions>();
  options->mutable_face_condition_type_options()
      ->mutable_face_landmarker_graph_options()
      ->mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(
          file::JoinPath("./", kTestDataDirectory, kFaceLandmarkerModel));
  // The graph requires exactly one face (see ConditionedImageGraph).
  options->mutable_face_condition_type_options()
      ->mutable_face_landmarker_graph_options()
      ->mutable_face_detector_graph_options()
      ->set_num_faces(1);
  MP_ASSERT_OK_AND_ASSIGN(
      auto runner, CreateConditionedImageGraphTaskRunner(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory,
                                                      kPortraitImage)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_packets,
      runner->Process({{kImageInStream, MakePacket<Image>(std::move(image))}}));
  const auto& output_image = output_packets[kImageOutStream].Get<Image>();
  MP_EXPECT_OK(SavePngTestOutput(*output_image.GetImageFrameSharedPtr(),
                                 "face_landmarks_image"));
}
// Runs the depth condition branch on a portrait image and saves the resulting
// depth guidance image as test output. Only checks successful execution.
TEST(ConditionedImageGraphTest, SucceedsDepthConditionType) {
  auto options = std::make_unique<proto::ConditionedImageGraphOptions>();
  options->mutable_depth_condition_type_options()
      ->mutable_image_segmenter_graph_options()
      ->mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(file::JoinPath("./", kTestDataDirectory, kDepthModel));
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory,
                                                      kPortraitImage)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto runner, CreateConditionedImageGraphTaskRunner(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_packets,
      runner->Process({{kImageInStream, MakePacket<Image>(std::move(image))}}));
  const auto& output_image = output_packets[kImageOutStream].Get<Image>();
  MP_EXPECT_OK(
      SavePngTestOutput(*output_image.GetImageFrameSharedPtr(), "depth_image"));
}
// Runs the Canny edge condition branch on a portrait image and saves the
// resulting edge guidance image as test output. Only checks successful
// execution.
TEST(ConditionedImageGraphTest, SucceedsEdgeConditionType) {
  auto options = std::make_unique<proto::ConditionedImageGraphOptions>();
  auto edge_condition_type_options =
      options->mutable_edge_condition_type_options();
  // Standard Canny hysteresis thresholds and 3x3 Sobel aperture.
  edge_condition_type_options->set_threshold_1(100);
  edge_condition_type_options->set_threshold_2(200);
  edge_condition_type_options->set_aperture_size(3);
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory,
                                                      kPortraitImage)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto runner, CreateConditionedImageGraphTaskRunner(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_packets,
      runner->Process({{kImageInStream, MakePacket<Image>(std::move(image))}}));
  const auto& output_image = output_packets[kImageOutStream].Get<Image>();
  MP_EXPECT_OK(
      SavePngTestOutput(*output_image.GetImageFrameSharedPtr(), "edges_image"));
}
} // namespace
} // namespace image_generator
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,70 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])
# C API header used to talk to the (separately loaded) diffuser GPU
# implementation.
cc_library(
    name = "diffuser_gpu_header",
    hdrs = ["diffuser_gpu.h"],
    visibility = [
        "//mediapipe/tasks/cc/vision/image_generator/diffuser:__pkg__",
    ],
)

mediapipe_proto_library(
    name = "stable_diffusion_iterate_calculator_proto",
    srcs = ["stable_diffusion_iterate_calculator.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
)

# Calculator driving the iterative stable-diffusion denoising loop through the
# diffuser C API.
cc_library(
    name = "stable_diffusion_iterate_calculator",
    srcs = ["stable_diffusion_iterate_calculator.cc"],
    deps = [
        ":diffuser_gpu_header",
        ":stable_diffusion_iterate_calculator_cc_proto",
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/deps:file_helpers",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/formats:tensor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
    ],
    # Registered via static initializers; must always be linked in.
    alwayslink = 1,
)

# Calculator forwarding control-plugin output tensors to the diffusion loop.
cc_library(
    name = "diffusion_plugins_output_calculator",
    srcs = ["diffusion_plugins_output_calculator.cc"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/formats:tensor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    # Registered via static initializers; must always be linked in.
    alwayslink = 1,
)

View File

@ -0,0 +1,88 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_
#include <limits.h>
#include <stdint.h>
#ifndef DG_EXPORT
#define DG_EXPORT __attribute__((visibility("default")))
#endif // DG_EXPORT
#ifdef __cplusplus
extern "C" {
#endif
// Supported diffusion model architectures.
enum DiffuserModelType {
  kDiffuserModelTypeSd1,
  kDiffuserModelTypeGldm,
  kDiffuserModelTypeDistilledGldm,
  kDiffuserModelTypeSd2Base,
  kDiffuserModelTypeTigo,
};

// Scheduling priority hint passed to the GPU backend.
enum DiffuserPriorityHint {
  kDiffuserPriorityHintHigh,
  kDiffuserPriorityHintNormal,
  kDiffuserPriorityHintLow,
};

// Performance/power trade-off hint passed to the GPU backend.
enum DiffuserPerformanceHint {
  kDiffuserPerformanceHintHigh,
  kDiffuserPerformanceHintNormal,
  kDiffuserPerformanceHintLow,
};

// Execution environment options for the diffuser backend.
typedef struct {
  DiffuserPriorityHint priority_hint;
  DiffuserPerformanceHint performance_hint;
} DiffuserEnvironmentOptions;

// Configuration used to create a diffuser context. The implementation is
// loaded separately, so field semantics beyond their names are not visible
// here; hedged notes below should be confirmed against the implementation.
typedef struct {
  DiffuserModelType model_type;
  char model_dir[PATH_MAX];  // Directory containing the model files.
  char lora_dir[PATH_MAX];   // LoRA weights directory; presumably empty when
                             // unused — confirm with the implementation.
  const void* lora_weights_layer_mapping;  // Opaque mapping; layout defined by
                                           // the implementation.
  int lora_rank;
  int seed;  // NOTE(review): presumably seeds the latent noise — confirm.
  int image_width;
  int image_height;
  int run_unet_with_plugins;  // C-style boolean (0/1).
  float plugins_strength;
  DiffuserEnvironmentOptions env_options;
} DiffuserConfig;

// Opaque handle to a created diffuser instance.
typedef struct {
  void* diffuser;
} DiffuserContext;

// A plugin tensor handed to the diffuser (up to 4-D; row-major layout
// presumed — confirm with the implementation).
typedef struct {
  int shape[4];
  const float* data;
} DiffuserPluginTensor;

// C entry points resolved from the diffuser shared library. The int return
// values are status codes; their exact meaning is not visible in this header
// (NOTE(review): confirm success/failure convention with the implementation).
DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*);  // NOLINT
DG_EXPORT int DiffuserReset(DiffuserContext*,                      // NOLINT
                            const char*, int, int, const void*);
DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int);  // NOLINT
DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*);   // NOLINT
DG_EXPORT void DiffuserDelete(DiffuserContext*);            // NOLINT
#ifdef __cplusplus
}
#endif // __cplusplus
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_

View File

@ -0,0 +1,67 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>
#include <vector>
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe {
namespace api2 {
// In iteration mode, output the image guidance tensors at the current timestamp
// and advance the output stream timestamp bound by the number of steps.
// Otherwise, output the image guidance tensors at the current timestamp only.
class DiffusionPluginsOutputCalculator : public Node {
 public:
  // Plugin guidance tensors; consumed (moved, not copied) by Process().
  static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
  // Total number of diffusion steps; used to advance the timestamp bound.
  static constexpr Input<int> kStepsIn{"STEPS"};
  // Current iteration; when connected the calculator runs in iteration mode.
  static constexpr Input<int>::Optional kIterationIn{"ITERATION"};
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
  MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kStepsIn, kIterationIn, kTensorsOut);
  absl::Status Process(CalculatorContext* cc) override {
    // No tensors at this timestamp: nothing to forward.
    if (kTensorsIn(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    // Consumes the tensor vector to avoid data copy.
    absl::StatusOr<std::unique_ptr<std::vector<Tensor>>> status_or_tensor =
        cc->Inputs().Tag("TENSORS").Value().Consume<std::vector<Tensor>>();
    if (!status_or_tensor.ok()) {
      return absl::InternalError("Input tensor vector is not consumable.");
    }
    if (kIterationIn(cc).IsConnected()) {
      // In iteration mode tensors are only expected on the first iteration;
      // a violation crashes the process (CHECK, not a returned error).
      CHECK_EQ(kIterationIn(cc).Get(), 0);
      kTensorsOut(cc).Send(std::move(*status_or_tensor.value()));
      // Advance the bound past the per-step timestamps so downstream nodes do
      // not block waiting for tensors that will never be sent.
      kTensorsOut(cc).SetNextTimestampBound(cc->InputTimestamp() +
                                            kStepsIn(cc).Get());
    } else {
      kTensorsOut(cc).Send(std::move(*status_or_tensor.value()));
    }
    return absl::OkStatus();
  }
};
MEDIAPIPE_REGISTER_NODE(DiffusionPluginsOutputCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,278 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <dlfcn.h>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_helpers.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h"
#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h"
namespace mediapipe {
namespace api2 {
namespace {
DiffuserPriorityHint ToDiffuserPriorityHint(
StableDiffusionIterateCalculatorOptions::ClPriorityHint priority) {
switch (priority) {
case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_LOW:
return kDiffuserPriorityHintLow;
case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_NORMAL:
return kDiffuserPriorityHintNormal;
case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_HIGH:
return kDiffuserPriorityHintHigh;
}
return kDiffuserPriorityHintNormal;
}
DiffuserModelType ToDiffuserModelType(
StableDiffusionIterateCalculatorOptions::ModelType model_type) {
switch (model_type) {
case StableDiffusionIterateCalculatorOptions::DEFAULT:
case StableDiffusionIterateCalculatorOptions::SD_1:
return kDiffuserModelTypeSd1;
}
return kDiffuserModelTypeSd1;
}
} // namespace
// Runs diffusion models including, but not limited to, Stable Diffusion & gLDM.
//
// Inputs:
// PROMPT - std::string
// The prompt used to generate the image.
// STEPS - int
// The number of steps to run the UNet.
// ITERATION - int
// The iteration of the current run.
// PLUGIN_TENSORS - std::vector<mediapipe::Tensor> @Optional
// The output tensor vector of the diffusion plugins model.
//
// Outputs:
// IMAGE - mediapipe::ImageFrame
// The image generated by the Stable Diffusion model from the input prompt.
// The output image is in RGB format.
//
// Example:
// node {
// calculator: "StableDiffusionIterateCalculator"
// input_stream: "PROMPT:prompt"
// input_stream: "STEPS:steps"
// output_stream: "IMAGE:result"
// options {
// [mediapipe.StableDiffusionIterateCalculatorOptions.ext] {
// base_seed: 0
// model_type: SD_1
// }
// }
// }
class StableDiffusionIterateCalculator : public Node {
 public:
  static constexpr Input<std::string> kPromptIn{"PROMPT"};
  static constexpr Input<int> kStepsIn{"STEPS"};
  // When connected, the calculator runs one denoising step per packet;
  // otherwise it runs all steps in a single Process() call.
  static constexpr Input<int>::Optional kIterationIn{"ITERATION"};
  // Overrides the base seed from options when present (absolute value used).
  static constexpr Input<int>::Optional kRandSeedIn{"RAND_SEED"};
  // Side-packet options take precedence over the node's static options in
  // Open().
  static constexpr SideInput<StableDiffusionIterateCalculatorOptions>::Optional
      kOptionsIn{"OPTIONS"};
  static constexpr Input<std::vector<Tensor>>::Optional kPlugInTensorsIn{
      "PLUGIN_TENSORS"};
  static constexpr Output<mediapipe::ImageFrame> kImageOut{"IMAGE"};
  MEDIAPIPE_NODE_CONTRACT(kPromptIn, kStepsIn, kIterationIn, kRandSeedIn,
                          kPlugInTensorsIn, kOptionsIn, kImageOut);
  // Tear down the plugin context before unloading the shared library that
  // provides its code.
  ~StableDiffusionIterateCalculator() {
    if (context_) DiffuserDelete();
    if (handle_) dlclose(handle_);
  }
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;
 private:
  // Converts the input MediaPipe tensors into the plugin's tensor structs.
  // The returned structs borrow the tensors' CPU buffers, so they are only
  // valid while the input packets are alive.
  std::vector<DiffuserPluginTensor> GetPluginTensors(
      CalculatorContext* cc) const {
    if (!kPlugInTensorsIn(cc).IsConnected()) return {};
    // NOTE(review): assumes the connected stream is non-empty and each tensor
    // is 4-D -- TODO confirm upstream guarantees.
    std::vector<DiffuserPluginTensor> diffuser_tensors;
    diffuser_tensors.reserve(kPlugInTensorsIn(cc)->size());
    for (const auto& mp_tensor : *kPlugInTensorsIn(cc)) {
      DiffuserPluginTensor diffuser_tensor;
      diffuser_tensor.shape[0] = mp_tensor.shape().dims[0];
      diffuser_tensor.shape[1] = mp_tensor.shape().dims[1];
      diffuser_tensor.shape[2] = mp_tensor.shape().dims[2];
      diffuser_tensor.shape[3] = mp_tensor.shape().dims[3];
      diffuser_tensor.data = mp_tensor.GetCpuReadView().buffer<float>();
      diffuser_tensors.push_back(diffuser_tensor);
    }
    return diffuser_tensors;
  }
  // Loads libimagegenerator_gpu.so at runtime and resolves the five plugin
  // entry points. Fails with the dlerror() message if any symbol is missing.
  absl::Status LoadDiffuser() {
    handle_ = dlopen("libimagegenerator_gpu.so", RTLD_NOW | RTLD_LOCAL);
    RET_CHECK(handle_) << dlerror();
    create_ptr_ = reinterpret_cast<DiffuserContext* (*)(const DiffuserConfig*)>(
        dlsym(handle_, "DiffuserCreate"));
    RET_CHECK(create_ptr_) << dlerror();
    reset_ptr_ =
        reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int,
                                 const void*)>(dlsym(handle_, "DiffuserReset"));
    RET_CHECK(reset_ptr_) << dlerror();
    iterate_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, int, int)>(
        dlsym(handle_, "DiffuserIterate"));
    RET_CHECK(iterate_ptr_) << dlerror();
    decode_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, uint8_t*)>(
        dlsym(handle_, "DiffuserDecode"));
    RET_CHECK(decode_ptr_) << dlerror();
    delete_ptr_ = reinterpret_cast<void (*)(DiffuserContext*)>(
        dlsym(handle_, "DiffuserDelete"));
    RET_CHECK(delete_ptr_) << dlerror();
    return absl::OkStatus();
  }
  // Thin wrappers over the dlsym-resolved entry points; a nonzero plugin
  // return converts to true (success).
  DiffuserContext* DiffuserCreate(const DiffuserConfig* a) {
    return (*create_ptr_)(a);
  }
  bool DiffuserReset(const char* a, int b, int c,
                     const std::vector<DiffuserPluginTensor>* d) {
    return (*reset_ptr_)(context_, a, b, c, d);
  }
  bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); }
  bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); }
  void DiffuserDelete() { (*delete_ptr_)(context_); }
  // dlopen handle for the plugin shared library; owned, closed in dtor.
  void* handle_ = nullptr;
  // Plugin context; owned, deleted via DiffuserDelete() in dtor.
  DiffuserContext* context_ = nullptr;
  DiffuserContext* (*create_ptr_)(const DiffuserConfig*);
  int (*reset_ptr_)(DiffuserContext*, const char*, int, int, const void*);
  int (*iterate_ptr_)(DiffuserContext*, int, int);
  int (*decode_ptr_)(DiffuserContext*, uint8_t*);
  void (*delete_ptr_)(DiffuserContext*);
  // Cached from options in Open(); see the .proto for semantics.
  int show_every_n_iteration_;
  bool emit_empty_packet_;
};
absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) {
  // The OPTIONS side packet, when supplied, takes precedence over the node's
  // static options.
  StableDiffusionIterateCalculatorOptions options;
  if (kOptionsIn(cc).IsEmpty()) {
    options = cc->Options<StableDiffusionIterateCalculatorOptions>();
  } else {
    options = kOptionsIn(cc).Get();
  }
  show_every_n_iteration_ = options.show_every_n_iteration();
  emit_empty_packet_ = options.emit_empty_packet();
  MP_RETURN_IF_ERROR(LoadDiffuser());
  DiffuserConfig config;
  config.model_type = ToDiffuserModelType(options.model_type());
  // Bound-check before copying into the fixed-size buffers in DiffuserConfig;
  // the previous unchecked strcpy could overflow on an oversized path.
  const std::string model_dir =
      options.file_folder().empty() ? "bins/" : options.file_folder();
  RET_CHECK_LT(model_dir.size(), sizeof(config.model_dir))
      << "file_folder path is too long: " << model_dir;
  std::strcpy(config.model_dir, model_dir.c_str());  // NOLINT
  MP_RETURN_IF_ERROR(mediapipe::file::Exists(config.model_dir))
      << config.model_dir;
  RET_CHECK(options.lora_file_folder().empty() ||
            options.lora_weights_layer_mapping().empty())
      << "Can't set both lora_file_folder and lora_weights_layer_mapping.";
  RET_CHECK_LT(options.lora_file_folder().size(), sizeof(config.lora_dir))
      << "lora_file_folder path is too long: " << options.lora_file_folder();
  std::strcpy(config.lora_dir, options.lora_file_folder().c_str());  // NOLINT
  // The proto carries uint64 buffer positions; the plugin ABI smuggles them
  // through pointer-typed map values. The map itself is local -- DiffuserCreate
  // is assumed to copy what it needs before Open() returns, TODO confirm.
  std::map<std::string, const char*> lora_weights_layer_mapping;
  for (auto& layer_name_and_weights : options.lora_weights_layer_mapping()) {
    lora_weights_layer_mapping[layer_name_and_weights.first] =
        (char*)layer_name_and_weights.second;
  }
  config.lora_weights_layer_mapping = !lora_weights_layer_mapping.empty()
                                          ? &lora_weights_layer_mapping
                                          : nullptr;
  config.lora_rank = options.lora_rank();
  config.seed = options.base_seed();
  config.image_width = options.output_image_width();
  config.image_height = options.output_image_height();
  config.run_unet_with_plugins = kPlugInTensorsIn(cc).IsConnected();
  config.env_options = {
      .priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()),
      .performance_hint = kDiffuserPerformanceHintHigh,
  };
  config.plugins_strength = options.plugins_strength();
  // BUGFIX: the original check (strength > 0 || strength < 1) was a tautology
  // and could never fail; enforce the documented inclusive range [0, 1].
  RET_CHECK(config.plugins_strength >= 0.0f && config.plugins_strength <= 1.0f)
      << "The value of plugins_strength must be in the range of [0, 1].";
  context_ = DiffuserCreate(&config);
  RET_CHECK(context_);
  return absl::OkStatus();
}
absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
  // NOTE(review): this reads the node's static options, whereas Open() prefers
  // the OPTIONS side packet when connected. If both are supplied with
  // different output dimensions or base_seed, Process() and Open() disagree --
  // verify intended precedence.
  const auto& options =
      cc->Options().GetExtension(StableDiffusionIterateCalculatorOptions::ext);
  const std::string& prompt = *kPromptIn(cc);
  const int steps = *kStepsIn(cc);
  // The RAND_SEED stream, when non-empty, overrides the configured base seed;
  // abs() because the plugin expects a non-negative seed.
  const int rand_seed = !kRandSeedIn(cc).IsEmpty() ? std::abs(*kRandSeedIn(cc))
                                                   : options.base_seed();
  if (kIterationIn(cc).IsEmpty()) {
    // One-shot mode: run every denoising step and decode a single image.
    const auto plugin_tensors = GetPluginTensors(cc);
    RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors));
    for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
    ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
                         options.output_image_height());
    RET_CHECK(DiffuserDecode(image_out.MutablePixelData()));
    kImageOut(cc).Send(std::move(image_out));
  } else {
    // Iteration mode: one denoising step per Process() call.
    const int iteration = *kIterationIn(cc);
    RET_CHECK_LT(iteration, steps);
    // Extract text embedding on first iteration.
    if (iteration == 0) {
      const auto plugin_tensors = GetPluginTensors(cc);
      RET_CHECK(
          DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors));
    }
    RET_CHECK(DiffuserIterate(steps, iteration));
    // Decode the output and send out the image for visualization.
    if ((iteration + 1) % show_every_n_iteration_ == 0 ||
        iteration == steps - 1) {
      ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
                           options.output_image_height());
      RET_CHECK(DiffuserDecode(image_out.MutablePixelData()));
      kImageOut(cc).Send(std::move(image_out));
    } else if (emit_empty_packet_) {
      // Empty packet acts as a progress signal (forces a GPU-CPU sync).
      kImageOut(cc).Send(Packet<mediapipe::ImageFrame>());
    }
  }
  return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(StableDiffusionIterateCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,84 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
option java_package = "com.google.mediapipe.calculator.proto";
option java_outer_classname = "StableDiffusionIterateCalculatorOptionsProto";
message StableDiffusionIterateCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional StableDiffusionIterateCalculatorOptions ext = 510855836;
  }
  // The random seed that is fed into the calculator to control the randomness
  // of the generated image.
  optional uint32 base_seed = 1 [default = 0];
  // The target output image size. Must be a multiple of 8 and larger than 384.
  optional int32 output_image_width = 2 [default = 512];
  optional int32 output_image_height = 3 [default = 512];
  // Folder containing the model files. The folder name must end with '/'.
  optional string file_folder = 4 [default = "bins/"];
  // Note: only one of lora_file_folder and lora_weights_layer_mapping should be
  // set.
  // The LoRA file folder. The folder name must end with '/'.
  optional string lora_file_folder = 9 [default = ""];
  // The LoRA layer name mapping to the weight buffer position in the file.
  map<string, uint64> lora_weights_layer_mapping = 10;
  // The LoRA rank.
  optional int32 lora_rank = 12 [default = 4];
  // Determines how often (every how many iterations) image decoding runs.
  // Setting this to 1 means we run the image decoding for every iteration for
  // displaying the intermediate result, but it will also introduce much higher
  // overall latency.
  // Setting this to be the targeted number of iterations will only run the
  // image decoding at the end, giving the best overall latency.
  optional int32 show_every_n_iteration = 5 [default = 1];
  // If set to be True, the calculator will perform a GPU-CPU sync and emit an
  // empty packet. It is used to provide the signal of which iteration it is
  // currently at, typically used to create a progress bar. Note that this also
  // introduces overhead, but not significantly based on our experiments (~1ms).
  optional bool emit_empty_packet = 6 [default = false];
  enum ClPriorityHint {
    PRIORITY_HINT_NORMAL = 0;  // Default, must be first.
    PRIORITY_HINT_LOW = 1;
    PRIORITY_HINT_HIGH = 2;
  }
  // OpenCL priority hint. Set this to LOW to yield to other GPU contexts.
  // This lowers inference speed, but helps keeping the UI responsive.
  optional ClPriorityHint cl_priority_hint = 7;
  enum ModelType {
    DEFAULT = 0;
    SD_1 = 1;  // Stable Diffusion v1 models, including SD 1.4 and 1.5.
  }
  // Stable Diffusion model type. Default to Stable Diffusion v1.
  optional ModelType model_type = 8 [default = SD_1];
  // The strength of the diffusion plugins inputs, in the range [0, 1].
  optional float plugins_strength = 11 [default = 1.0];
}

View File

@ -0,0 +1,397 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/image_generator/image_generator.h"
#include <memory>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/image_generator_result.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_generator {
namespace {
using ImageGeneratorGraphOptionsProto = ::mediapipe::tasks::vision::
image_generator::proto::ImageGeneratorGraphOptions;
using ConditionedImageGraphOptionsProto = ::mediapipe::tasks::vision::
image_generator::proto::ConditionedImageGraphOptions;
using ControlPluginGraphOptionsProto = ::mediapipe::tasks::vision::
image_generator::proto::ControlPluginGraphOptions;
using FaceLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
face_landmarker::proto::FaceLandmarkerGraphOptions;
constexpr absl::string_view kImageTag = "IMAGE";
constexpr absl::string_view kImageOutName = "image_out";
constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
constexpr absl::string_view kConditionImageName = "condition_image";
constexpr absl::string_view kSourceConditionImageName =
"source_condition_image";
constexpr absl::string_view kStepsTag = "STEPS";
constexpr absl::string_view kStepsName = "steps";
constexpr absl::string_view kIterationTag = "ITERATION";
constexpr absl::string_view kIterationName = "iteration";
constexpr absl::string_view kPromptTag = "PROMPT";
constexpr absl::string_view kPromptName = "prompt";
constexpr absl::string_view kRandSeedTag = "RAND_SEED";
constexpr absl::string_view kRandSeedName = "rand_seed";
constexpr absl::string_view kSelectTag = "SELECT";
constexpr absl::string_view kSelectName = "select";
constexpr char kImageGeneratorGraphTypeName[] =
"mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
constexpr char kConditionedImageGraphContainerTypeName[] =
"mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph".
CalculatorGraphConfig CreateImageGeneratorGraphConfig(
    std::unique_ptr<ImageGeneratorGraphOptionsProto> options,
    bool use_condition_image) {
  api2::builder::Graph graph;
  auto& subgraph = graph.AddNode(kImageGeneratorGraphTypeName);
  subgraph.GetOptions<ImageGeneratorGraphOptionsProto>().CopyFrom(*options);
  // Wire the named graph-level input streams through to the subgraph.
  graph.In(kStepsTag).SetName(kStepsName) >> subgraph.In(kStepsTag);
  graph.In(kIterationTag).SetName(kIterationName) >> subgraph.In(kIterationTag);
  graph.In(kPromptTag).SetName(kPromptName) >> subgraph.In(kPromptTag);
  graph.In(kRandSeedTag).SetName(kRandSeedName) >> subgraph.In(kRandSeedTag);
  // Condition image and plugin-selection inputs only exist when the task was
  // created with ConditionOptions.
  if (use_condition_image) {
    graph.In(kConditionImageTag).SetName(kConditionImageName) >>
        subgraph.In(kConditionImageTag);
    graph.In(kSelectTag).SetName(kSelectName) >> subgraph.In(kSelectTag);
  }
  // Output is optional: intermediate iterations may emit empty packets.
  subgraph.Out(kImageTag).SetName(kImageOutName) >>
      graph[api2::Output<Image>::Optional(kImageTag)];
  return graph.GetConfig();
}
// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer".
CalculatorGraphConfig CreateConditionedImageGraphContainerConfig(
    std::unique_ptr<ImageGeneratorGraphOptionsProto> options) {
  api2::builder::Graph graph;
  auto& subgraph = graph.AddNode(kConditionedImageGraphContainerTypeName);
  subgraph.GetOptions<ImageGeneratorGraphOptionsProto>().CopyFrom(*options);
  // The source image plus a SELECT index choosing which condition-type
  // subgraph (face/depth/edge) processes it.
  graph.In(kImageTag).SetName(kSourceConditionImageName) >>
      subgraph.In(kImageTag);
  graph.In(kSelectTag).SetName(kSelectName) >> subgraph.In(kSelectTag);
  subgraph.Out(kConditionImageTag).SetName(kConditionImageName) >>
      graph.Out(kConditionImageTag).Cast<Image>();
  return graph.GetConfig();
}
// Fills options_proto with the face-condition plugin model options and the
// nested FaceLandmarkerGraph options derived from face_condition_options.
absl::Status SetFaceConditionOptionsToProto(
    FaceConditionOptions& face_condition_options,
    ControlPluginGraphOptionsProto& options_proto) {
  // Configure face plugin model.
  auto plugin_base_options_proto =
      std::make_unique<tasks::core::proto::BaseOptions>(
          tasks::core::ConvertBaseOptionsToProto(
              &(face_condition_options.base_options)));
  options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get());
  // Configure face landmarker graph.
  auto& face_landmarker_options =
      face_condition_options.face_landmarker_options;
  auto& face_landmarker_options_proto =
      *options_proto.mutable_conditioned_image_graph_options()
           ->mutable_face_condition_type_options()
           ->mutable_face_landmarker_graph_options();
  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
      tasks::core::ConvertBaseOptionsToProto(
          &(face_landmarker_options.base_options)));
  face_landmarker_options_proto.mutable_base_options()->Swap(
      base_options_proto.get());
  // The condition image is produced from a single still image, not a stream.
  face_landmarker_options_proto.mutable_base_options()->set_use_stream_mode(
      false);
  // Configure face detector options.
  auto* face_detector_graph_options =
      face_landmarker_options_proto.mutable_face_detector_graph_options();
  face_detector_graph_options->set_num_faces(face_landmarker_options.num_faces);
  face_detector_graph_options->set_min_detection_confidence(
      face_landmarker_options.min_face_detection_confidence);
  // Configure face landmark detector options.
  face_landmarker_options_proto.set_min_tracking_confidence(
      face_landmarker_options.min_tracking_confidence);
  auto* face_landmarks_detector_graph_options =
      face_landmarker_options_proto
          .mutable_face_landmarks_detector_graph_options();
  face_landmarks_detector_graph_options->set_min_detection_confidence(
      face_landmarker_options.min_face_presence_confidence);
  return absl::OkStatus();
}
// Fills options_proto with the depth-condition plugin model options and the
// nested ImageSegmenterGraph (depth estimation) options.
absl::Status SetDepthConditionOptionsToProto(
    DepthConditionOptions& depth_condition_options,
    ControlPluginGraphOptionsProto& options_proto) {
  // Configure depth plugin model.
  auto plugin_base_options_proto =
      std::make_unique<tasks::core::proto::BaseOptions>(
          tasks::core::ConvertBaseOptionsToProto(
              &(depth_condition_options.base_options)));
  options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get());
  auto& image_segmenter_graph_options =
      *options_proto.mutable_conditioned_image_graph_options()
           ->mutable_depth_condition_type_options()
           ->mutable_image_segmenter_graph_options();
  auto depth_base_options_proto =
      std::make_unique<tasks::core::proto::BaseOptions>(
          tasks::core::ConvertBaseOptionsToProto(
              &(depth_condition_options.image_segmenter_options.base_options)));
  image_segmenter_graph_options.mutable_base_options()->Swap(
      depth_base_options_proto.get());
  // The condition image is produced from a single still image, not a stream.
  image_segmenter_graph_options.mutable_base_options()->set_use_stream_mode(
      false);
  image_segmenter_graph_options.set_display_names_locale(
      depth_condition_options.image_segmenter_options.display_names_locale);
  return absl::OkStatus();
}
// Fills options_proto with the edge-condition plugin model options and the
// Canny edge-detection parameters from edge_condition_options.
absl::Status SetEdgeConditionOptionsToProto(
    EdgeConditionOptions& edge_condition_options,
    ControlPluginGraphOptionsProto& options_proto) {
  // Plugin model base options.
  auto base_options = std::make_unique<tasks::core::proto::BaseOptions>(
      tasks::core::ConvertBaseOptionsToProto(
          &(edge_condition_options.base_options)));
  options_proto.mutable_base_options()->Swap(base_options.get());
  // Edge-detection parameters.
  auto* edge_options = options_proto.mutable_conditioned_image_graph_options()
                           ->mutable_edge_condition_type_options();
  edge_options->set_l2_gradient(edge_condition_options.l2_gradient);
  edge_options->set_aperture_size(edge_condition_options.aperture_size);
  edge_options->set_threshold_2(edge_condition_options.threshold_2);
  edge_options->set_threshold_1(edge_condition_options.threshold_1);
  return absl::OkStatus();
}
// Helper holder struct of image generator graph options and condition type
// index mapping.
struct ImageGeneratorOptionsProtoAndConditionTypeIndex {
  // Fully-populated graph options for the image generator graph.
  std::unique_ptr<ImageGeneratorGraphOptionsProto> options_proto;
  // Maps each requested condition type to its plugin-graph index (the SELECT
  // value); null when no condition options were supplied.
  std::unique_ptr<std::map<ConditionOptions::ConditionType, int>>
      condition_type_index;
};
// Converts the user-facing ImageGeneratorOptions struct to the internal
// ImageGeneratorOptions proto.
absl::StatusOr<ImageGeneratorOptionsProtoAndConditionTypeIndex>
ConvertImageGeneratorGraphOptionsProto(
    ImageGeneratorOptions* image_generator_options,
    ConditionOptions* condition_options) {
  ImageGeneratorOptionsProtoAndConditionTypeIndex
      options_proto_and_condition_index;
  // Configure base image generator options.
  options_proto_and_condition_index.options_proto =
      std::make_unique<ImageGeneratorGraphOptionsProto>();
  auto& options_proto = *options_proto_and_condition_index.options_proto;
  options_proto.set_text2image_model_directory(
      image_generator_options->text2image_model_directory);
  if (image_generator_options->lora_weights_file_path.has_value()) {
    options_proto.mutable_lora_weights_file()->set_file_name(
        *image_generator_options->lora_weights_file_path);
  }
  // Configure optional condition type options. Each condition type present is
  // assigned the next plugin-graph index, in FACE, DEPTH, EDGE order.
  if (condition_options != nullptr) {
    options_proto_and_condition_index.condition_type_index =
        std::make_unique<std::map<ConditionOptions::ConditionType, int>>();
    auto& condition_type_index =
        *options_proto_and_condition_index.condition_type_index;
    if (condition_options->face_condition_options.has_value()) {
      // NOTE(review): `map[k] = map.size()` relies on the C++17 guarantee that
      // the RHS is sequenced before the insertion; pre-C++17 this would be
      // unspecified.
      condition_type_index[ConditionOptions::FACE] =
          condition_type_index.size();
      auto& face_plugin_graph_options =
          *options_proto.add_control_plugin_graphs_options();
      RET_CHECK_OK(SetFaceConditionOptionsToProto(
          *condition_options->face_condition_options,
          face_plugin_graph_options));
    }
    if (condition_options->depth_condition_options.has_value()) {
      condition_type_index[ConditionOptions::DEPTH] =
          condition_type_index.size();
      auto& depth_plugin_graph_options =
          *options_proto.add_control_plugin_graphs_options();
      RET_CHECK_OK(SetDepthConditionOptionsToProto(
          *condition_options->depth_condition_options,
          depth_plugin_graph_options));
    }
    if (condition_options->edge_condition_options.has_value()) {
      condition_type_index[ConditionOptions::EDGE] =
          condition_type_index.size();
      auto& edge_plugin_graph_options =
          *options_proto.add_control_plugin_graphs_options();
      RET_CHECK_OK(SetEdgeConditionOptionsToProto(
          *condition_options->edge_condition_options,
          edge_plugin_graph_options));
    }
    // Passing ConditionOptions with nothing set is a caller error.
    if (condition_type_index.empty()) {
      return absl::InvalidArgumentError(
          "At least one condition type must be set.");
    }
  }
  return options_proto_and_condition_index;
}
} // namespace
absl::StatusOr<std::unique_ptr<ImageGenerator>> ImageGenerator::Create(
    std::unique_ptr<ImageGeneratorOptions> image_generator_options,
    std::unique_ptr<ConditionOptions> condition_options) {
  // Condition-image support is enabled solely by the presence of
  // condition_options.
  bool use_condition_image = condition_options != nullptr;
  ASSIGN_OR_RETURN(auto options_proto_and_condition_index,
                   ConvertImageGeneratorGraphOptionsProto(
                       image_generator_options.get(), condition_options.get()));
  // The conditioned-image container graph needs its own copy of the options,
  // since the original proto is moved into the generator graph config below.
  std::unique_ptr<proto::ImageGeneratorGraphOptions>
      options_proto_for_condition_image_graphs_container;
  if (use_condition_image) {
    options_proto_for_condition_image_graphs_container =
        std::make_unique<proto::ImageGeneratorGraphOptions>();
    options_proto_for_condition_image_graphs_container->CopyFrom(
        *options_proto_and_condition_index.options_proto);
  }
  ASSIGN_OR_RETURN(
      auto image_generator,
      (core::VisionTaskApiFactory::Create<ImageGenerator,
                                          ImageGeneratorGraphOptionsProto>(
          CreateImageGeneratorGraphConfig(
              std::move(options_proto_and_condition_index.options_proto),
              use_condition_image),
          std::make_unique<tasks::core::MediaPipeBuiltinOpResolver>(),
          core::RunningMode::IMAGE,
          /*result_callback=*/nullptr)));
  image_generator->use_condition_image_ = use_condition_image;
  if (use_condition_image) {
    image_generator->condition_type_index_ =
        std::move(options_proto_and_condition_index.condition_type_index);
    // A separate task runner hosts the condition-image graphs (face/depth/
    // edge) so they can be invoked independently of the generation graph.
    ASSIGN_OR_RETURN(
        image_generator->condition_image_graphs_container_task_runner_,
        tasks::core::TaskRunner::Create(
            CreateConditionedImageGraphContainerConfig(
                std::move(options_proto_for_condition_image_graphs_container)),
            absl::make_unique<tasks::core::MediaPipeBuiltinOpResolver>()));
  }
  // Reference point for the monotonically increasing packet timestamps used
  // in RunIterations().
  image_generator->init_timestamp_ = absl::Now();
  return image_generator;
}
absl::StatusOr<Image> ImageGenerator::CreateConditionImage(
    Image source_condition_image,
    ConditionOptions::ConditionType condition_type) {
  // BUGFIX: when the generator was created without ConditionOptions,
  // condition_type_index_ is null and the original code dereferenced it.
  // Fail with an explicit error instead.
  if (!use_condition_image_ || condition_type_index_ == nullptr) {
    return absl::FailedPreconditionError(
        "ImageGenerator is created to use without conditioned image.");
  }
  if (condition_type_index_->find(condition_type) ==
      condition_type_index_->end()) {
    return absl::InvalidArgumentError(
        "The condition type is not created during initialization.");
  }
  // Run the source image through the selected condition-image subgraph
  // (face landmarks / depth / edges) and return the rendered condition image.
  ASSIGN_OR_RETURN(
      auto output_packets,
      condition_image_graphs_container_task_runner_->Process({
          {std::string(kSourceConditionImageName),
           MakePacket<Image>(std::move(source_condition_image))},
          {std::string(kSelectName),
           MakePacket<int>(condition_type_index_->at(condition_type))},
      }));
  return output_packets.at(std::string(kConditionImageName)).Get<Image>();
}
// Generates an image from the text prompt alone. Only valid when the task was
// created WITHOUT condition options.
absl::StatusOr<ImageGeneratorResult> ImageGenerator::Generate(
    const std::string& prompt, int iterations, int seed) {
  if (!use_condition_image_) {
    return RunIterations(prompt, iterations, seed, std::nullopt);
  }
  return absl::InvalidArgumentError(
      "ImageGenerator is created to use with conditioned image.");
}
// Generates an image guided by a condition image. Only valid when the task
// was created WITH condition options; the condition image is first converted
// by the matching plugin subgraph.
absl::StatusOr<ImageGeneratorResult> ImageGenerator::Generate(
    const std::string& prompt, Image condition_image,
    ConditionOptions::ConditionType condition_type, int iterations, int seed) {
  if (!use_condition_image_) {
    return absl::InvalidArgumentError(
        "ImageGenerator is created to use without conditioned image.");
  }
  ASSIGN_OR_RETURN(auto plugin_model_image,
                   CreateConditionImage(condition_image, condition_type));
  // SELECT value routing the plugin tensors to the right control plugin graph.
  const int select_index = condition_type_index_->at(condition_type);
  return RunIterations(prompt, iterations, seed,
                       ConditionInputs{plugin_model_image, select_index});
}
// Runs the generation graph for `steps` iterations with the given prompt and
// random seed. If `condition_inputs` is set, the condition image and plugin
// selector are fed on the first iteration only; the result of the final
// iteration is returned.
absl::StatusOr<ImageGeneratorResult> ImageGenerator::RunIterations(
    const std::string& prompt, int steps, int rand_seed,
    std::optional<ConditionInputs> condition_inputs) {
  tasks::core::PacketMap output_packets;
  ImageGeneratorResult result;
  // Packet timestamps are milliseconds elapsed since the generator was
  // created, advanced by 1 per iteration so each Process() call is
  // monotonically later than the previous one.
  auto timestamp = (absl::Now() - init_timestamp_) / absl::Milliseconds(1);
  for (int i = 0; i < steps; ++i) {
    tasks::core::PacketMap input_packets;
    // Condition inputs are only injected on the first iteration.
    if (i == 0 && condition_inputs.has_value()) {
      input_packets[std::string(kConditionImageName)] =
          MakePacket<Image>(condition_inputs->condition_image)
              .At(Timestamp(timestamp));
      input_packets[std::string(kSelectName)] =
          MakePacket<int>(condition_inputs->select).At(Timestamp(timestamp));
    }
    input_packets[std::string(kStepsName)] =
        MakePacket<int>(steps).At(Timestamp(timestamp));
    input_packets[std::string(kIterationName)] =
        MakePacket<int>(i).At(Timestamp(timestamp));
    input_packets[std::string(kPromptName)] =
        MakePacket<std::string>(prompt).At(Timestamp(timestamp));
    input_packets[std::string(kRandSeedName)] =
        MakePacket<int>(rand_seed).At(Timestamp(timestamp));
    ASSIGN_OR_RETURN(output_packets, ProcessImageData(input_packets));
    timestamp += 1;
  }
  // Only the output of the last iteration is surfaced to the caller.
  result.generated_image =
      output_packets.at(std::string(kImageOutName)).Get<Image>();
  if (condition_inputs.has_value()) {
    // Echo back the derived condition image that was fed to the plugin model.
    result.condition_image = condition_inputs->condition_image;
  }
  return result;
}
} // namespace image_generator
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,157 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_
#include <memory>
#include <optional>
#include <string>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h"
#include "mediapipe/tasks/cc/vision/image_generator/image_generator_result.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_generator {
// Options for drawing face landmarks image.
struct FaceConditionOptions {
  // The base options for the face plugin model.
  tasks::core::BaseOptions base_options;

  // Face landmarker options used to detect face landmarks in the condition
  // image.
  face_landmarker::FaceLandmarkerOptions face_landmarker_options;
};
// Options for detecting edges image.
struct EdgeConditionOptions {
  // The base options for the edge plugin model.
  tasks::core::BaseOptions base_options;

  // These parameters are used to configure the Canny edge algorithm of OpenCV.
  // See more details:
  // https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de

  // First threshold for the hysteresis procedure.
  float threshold_1 = 100;

  // Second threshold for the hysteresis procedure.
  float threshold_2 = 200;

  // Aperture size for the Sobel operator. Typical range is 3~7.
  int aperture_size = 3;

  // A flag, indicating whether a more accurate L2 norm should be used to
  // calculate the image gradient magnitude ( L2gradient=true ), or whether
  // the default L1 norm is enough ( L2gradient=false ).
  bool l2_gradient = false;
};
// Options for detecting depth image.
struct DepthConditionOptions {
  // The base options for the depth plugin model.
  tasks::core::BaseOptions base_options;

  // Image segmenter options used to estimate depth in the condition image.
  image_segmenter::ImageSegmenterOptions image_segmenter_options;
};
// Options selecting which condition (plugin) types the generator supports.
// Each set optional field enables the corresponding ConditionType.
// NOTE(review): presumably at least one field must be set when this struct is
// passed to ImageGenerator::Create — confirm against the Create()
// implementation.
struct ConditionOptions {
  enum ConditionType { FACE, EDGE, DEPTH };
  std::optional<FaceConditionOptions> face_condition_options;
  std::optional<EdgeConditionOptions> edge_condition_options;
  std::optional<DepthConditionOptions> depth_condition_options;
};
// Note: The API is experimental and subject to change.
// The options for configuring a mediapipe image generator task.
struct ImageGeneratorOptions {
  // The text-to-image model directory storing the model weights.
  std::string text2image_model_directory;

  // The path to the LoRA weights file. If set, the diffusion model is created
  // with these LoRA weights.
  std::optional<std::string> lora_weights_file_path;
};
class ImageGenerator : tasks::vision::core::BaseVisionTaskApi {
 public:
  using BaseVisionTaskApi::BaseVisionTaskApi;

  // Creates an ImageGenerator from the provided options.
  // image_generator_options: options to create the image generator.
  // condition_options: optional options if plugin models are used to generate
  // an image based on the condition image.
  static absl::StatusOr<std::unique_ptr<ImageGenerator>> Create(
      std::unique_ptr<ImageGeneratorOptions> image_generator_options,
      std::unique_ptr<ConditionOptions> condition_options = nullptr);

  // Creates the condition image of the specified condition type from the
  // source condition image. Currently supports face landmarks, depth image
  // and edge image as the condition image. Returns InvalidArgumentError if
  // `condition_type` was not enabled via the ConditionOptions at Create().
  absl::StatusOr<Image> CreateConditionImage(
      Image source_condition_image,
      ConditionOptions::ConditionType condition_type);

  // Generates an image for `iterations` steps with the given random seed.
  // Only valid when the ImageGenerator is created without condition options.
  absl::StatusOr<ImageGeneratorResult> Generate(const std::string& prompt,
                                                int iterations, int seed = 0);

  // Generates an image based on the condition image for `iterations` steps
  // and the given random seed. Only valid when the ImageGenerator is created
  // with condition options.
  // A detailed introduction to the condition image:
  // https://ai.googleblog.com/2023/06/on-device-diffusion-plugins-for.html
  absl::StatusOr<ImageGeneratorResult> Generate(
      const std::string& prompt, Image condition_image,
      ConditionOptions::ConditionType condition_type, int iterations,
      int seed = 0);

 private:
  // Condition inputs fed to a single generation run.
  struct ConditionInputs {
    Image condition_image;
    // Index of the plugin graph selected by the condition type.
    int select;
  };

  // True when the generator was created with ConditionOptions.
  bool use_condition_image_ = false;

  // Creation time; packet timestamps are derived from elapsed time since.
  absl::Time init_timestamp_;

  // Runs the ConditionedImageGraphContainer to derive condition images.
  std::unique_ptr<tasks::core::TaskRunner>
      condition_image_graphs_container_task_runner_;

  // Maps each enabled condition type to its contained-graph index.
  std::unique_ptr<std::map<ConditionOptions::ConditionType, int>>
      condition_type_index_;

  // Runs the generation graph for `steps` iterations; condition inputs, if
  // present, are fed on the first iteration only.
  absl::StatusOr<ImageGeneratorResult> RunIterations(
      const std::string& prompt, int steps, int rand_seed,
      std::optional<ConditionInputs> condition_inputs);
};
} // namespace image_generator
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_

View File

@ -0,0 +1,361 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/switch_container.pb.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.pb.h"
#include "mediapipe/util/graph_builder_utils.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_generator {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
constexpr int kPluginsOutputSize = 512;
constexpr absl::string_view kTensorsTag = "TENSORS";
constexpr absl::string_view kImageTag = "IMAGE";
constexpr absl::string_view kImageCpuTag = "IMAGE_CPU";
constexpr absl::string_view kStepsTag = "STEPS";
constexpr absl::string_view kIterationTag = "ITERATION";
constexpr absl::string_view kPromptTag = "PROMPT";
constexpr absl::string_view kRandSeedTag = "RAND_SEED";
constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS";
constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
constexpr absl::string_view kSelectTag = "SELECT";
constexpr absl::string_view kMetadataFilename = "metadata";
constexpr absl::string_view kLoraRankStr = "lora_rank";
// Input streams wired into the image generation subgraph. The condition image
// and selector are only present when control plugin graphs are configured.
struct ImageGeneratorInputs {
  Source<std::string> prompt;
  Source<int> steps;
  Source<int> iteration;
  Source<int> rand_seed;
  std::optional<Source<Image>> condition_image;
  std::optional<Source<int>> select_condition_type;
};

// Output streams produced by the image generation subgraph.
struct ImageGeneratorOutputs {
  Source<Image> generated_image;
};
} // namespace
// A container graph containing several ConditionedImageGraph from which to
// choose specified condition type.
// Inputs:
// IMAGE - Image
// The source condition image, used to generate the condition image.
// SELECT - int
// The index of the selected conditioned image graph.
// Outputs:
// CONDITION_IMAGE - Image
// The condition image created from the specified condition type.
class ConditionedImageGraphContainer : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    // Reuses ImageGeneratorGraphOptions; only control_plugin_graphs_options
    // is read here, one entry per selectable condition type.
    auto& graph_options =
        *sc->MutableOptions<proto::ImageGeneratorGraphOptions>();
    auto source_condition_image = graph.In(kImageTag).Cast<Image>();
    auto select_condition_type = graph.In(kSelectTag).Cast<int>();
    // A SwitchContainer holds one ConditionedImageGraph per condition type;
    // the SELECT input picks which contained graph processes the image.
    auto& switch_container = graph.AddNode("SwitchContainer");
    auto& switch_options =
        switch_container.GetOptions<mediapipe::SwitchContainerOptions>();
    for (auto& control_plugin_graph_options :
         *graph_options.mutable_control_plugin_graphs_options()) {
      auto& node = *switch_options.add_contained_node();
      node.set_calculator(
          "mediapipe.tasks.vision.image_generator.ConditionedImageGraph");
      node.mutable_node_options()->Add()->PackFrom(
          control_plugin_graph_options.conditioned_image_graph_options());
    }
    source_condition_image >> switch_container.In(kImageTag);
    select_condition_type >> switch_container.In(kSelectTag);
    auto condition_image = switch_container.Out(kImageTag).Cast<Image>();
    condition_image >> graph.Out(kConditionImageTag);
    return graph.GetConfig();
  }
};
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::image_generator::ConditionedImageGraphContainer); // NOLINT
// clang-format on
// A helper graph to convert condition image to Tensor using the control plugin
// model.
// Inputs:
// CONDITION_IMAGE - Image
// The condition image input to the control plugin model.
// Outputs:
// PLUGIN_TENSORS - std::vector<Tensor>
// The output tensors from the control plugin model. The tensors are used as
// inputs to the image generation model.
class ControlPluginGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    auto& graph_options =
        *sc->MutableOptions<proto::ControlPluginGraphOptions>();
    auto condition_image = graph.In(kConditionImageTag).Cast<Image>();
    // Convert Image to ImageFrame.
    auto& from_image = graph.AddNode("FromImageCalculator");
    condition_image >> from_image.In(kImageTag);
    auto image_frame = from_image.Out(kImageCpuTag);
    // Convert ImageFrame to a kPluginsOutputSize x kPluginsOutputSize float
    // tensor scaled to [-1, 1], preserving aspect ratio.
    auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator");
    auto& image_to_tensor_options =
        image_to_tensor.GetOptions<mediapipe::ImageToTensorCalculatorOptions>();
    image_to_tensor_options.set_output_tensor_width(kPluginsOutputSize);
    image_to_tensor_options.set_output_tensor_height(kPluginsOutputSize);
    image_to_tensor_options.mutable_output_tensor_float_range()->set_min(-1);
    image_to_tensor_options.mutable_output_tensor_float_range()->set_max(1);
    image_to_tensor_options.set_keep_aspect_ratio(true);
    image_frame >> image_to_tensor.In(kImageTag);
    // Create the plugin model resource.
    ASSIGN_OR_RETURN(
        const core::ModelResources* plugin_model_resources,
        CreateModelResources(
            sc,
            std::make_unique<tasks::core::proto::ExternalFile>(
                *graph_options.mutable_base_options()->mutable_model_asset())));
    // Add control plugin model inference.
    auto& plugins_inference =
        AddInference(*plugin_model_resources,
                     graph_options.base_options().acceleration(), graph);
    image_to_tensor.Out(kTensorsTag) >> plugins_inference.In(kTensorsTag);
    // The plugins model is not runnable on OpenGL, so inference is pinned to
    // the XNNPACK (CPU) delegate. Observed error on GPU:
    // TfLiteGpuDelegate Prepare: Batch size mismatch, expected 1 but got 64
    // Node number 67 (TfLiteGpuDelegate) failed to prepare.
    plugins_inference.GetOptions<mediapipe::InferenceCalculatorOptions>()
        .mutable_delegate()
        ->mutable_xnnpack();
    plugins_inference.Out(kTensorsTag).Cast<std::vector<Tensor>>() >>
        graph.Out(kPluginTensorsTag);
    return graph.GetConfig();
  }
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::image_generator::ControlPluginGraph);
// A "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph" performs image
// generation from a text prompt, and an optional condition image.
//
// Inputs:
// PROMPT - std::string
// The prompt describing the image to be generated.
// STEPS - int
// The total steps to generate the image.
// ITERATION - int
// The current iteration in the generating steps. Must be less than STEPS.
// RAND_SEED - int
//     The random seed input to the image generation model.
// CONDITION_IMAGE - Image
//     The condition image used as guidance for the image generation. Only
//     valid if control plugin graph options are set in the graph options.
// SELECT - int
//     The index of the selected control plugin graph.
//
// Outputs:
// IMAGE - Image
// The generated image.
// STEPS - int @optional
// The total steps to generate the image. The same as STEPS input.
// ITERATION - int @optional
// The current iteration in the generating steps. The same as ITERATION
// input.
class ImageGeneratorGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    auto* subgraph_options =
        sc->MutableOptions<proto::ImageGeneratorGraphOptions>();
    std::optional<const core::ModelAssetBundleResources*> lora_resources;
    // Create LoRA weights asset bundle resources. The file proto is swapped
    // out of the options so the bundle takes ownership of the contents.
    if (subgraph_options->has_lora_weights_file()) {
      auto external_file = std::make_unique<tasks::core::proto::ExternalFile>();
      external_file->Swap(subgraph_options->mutable_lora_weights_file());
      ASSIGN_OR_RETURN(lora_resources, CreateModelAssetBundleResources(
                                           sc, std::move(external_file)));
    }
    // The CONDITION_IMAGE and SELECT inputs only exist when control plugin
    // graphs are configured.
    std::optional<Source<Image>> condition_image;
    std::optional<Source<int>> select_condition_type;
    if (!subgraph_options->control_plugin_graphs_options().empty()) {
      condition_image = graph.In(kConditionImageTag).Cast<Image>();
      select_condition_type = graph.In(kSelectTag).Cast<int>();
    }
    ASSIGN_OR_RETURN(
        auto outputs,
        BuildImageGeneratorGraph(
            *sc->MutableOptions<proto::ImageGeneratorGraphOptions>(),
            lora_resources,
            ImageGeneratorInputs{
                /*prompt=*/graph.In(kPromptTag).Cast<std::string>(),
                /*steps=*/graph.In(kStepsTag).Cast<int>(),
                /*iteration=*/graph.In(kIterationTag).Cast<int>(),
                /*rand_seed=*/graph.In(kRandSeedTag).Cast<int>(),
                /*condition_image*/ condition_image,
                /*select_condition_type*/ select_condition_type,
            },
            graph));
    outputs.generated_image >> graph.Out(kImageTag).Cast<Image>();
    // Optional outputs echoing the current iteration and total steps back to
    // the caller.
    auto& pass_through = graph.AddNode("PassThroughCalculator");
    graph.In(kIterationTag) >> pass_through.In(0);
    graph.In(kStepsTag) >> pass_through.In(1);
    pass_through.Out(0) >> graph[Output<int>::Optional(kIterationTag)];
    pass_through.Out(1) >> graph[Output<int>::Optional(kStepsTag)];
    return graph.GetConfig();
  }

  // Wires the StableDiffusion iterator (and, when condition inputs exist, the
  // control plugin switch container) into `graph` and returns the generated
  // image stream.
  absl::StatusOr<ImageGeneratorOutputs> BuildImageGeneratorGraph(
      proto::ImageGeneratorGraphOptions& subgraph_options,
      std::optional<const core::ModelAssetBundleResources*> lora_resources,
      ImageGeneratorInputs inputs, Graph& graph) {
    auto& stable_diff = graph.AddNode("StableDiffusionIterateCalculator");
    if (inputs.condition_image.has_value()) {
      // Add switch container for multiple control plugin graphs; the SELECT
      // input picks which plugin processes the condition image.
      auto& switch_container = graph.AddNode("SwitchContainer");
      auto& switch_options =
          switch_container.GetOptions<mediapipe::SwitchContainerOptions>();
      for (auto& control_plugin_graph_options :
           *subgraph_options.mutable_control_plugin_graphs_options()) {
        auto& node = *switch_options.add_contained_node();
        node.set_calculator(
            "mediapipe.tasks.vision.image_generator.ControlPluginGraph");
        node.mutable_node_options()->Add()->PackFrom(
            control_plugin_graph_options);
      }
      *inputs.condition_image >> switch_container.In(kConditionImageTag);
      *inputs.select_condition_type >> switch_container.In(kSelectTag);
      auto plugin_tensors = switch_container.Out(kPluginTensorsTag);
      // Additional diffusion plugins calculator to pass tensors to diffusion
      // iterator.
      auto& plugins_output = graph.AddNode("DiffusionPluginsOutputCalculator");
      plugin_tensors >> plugins_output.In(kTensorsTag);
      inputs.steps >> plugins_output.In(kStepsTag);
      inputs.iteration >> plugins_output.In(kIterationTag);
      plugins_output.Out(kTensorsTag) >> stable_diff.In(kPluginTensorsTag);
    }
    inputs.prompt >> stable_diff.In(kPromptTag);
    inputs.steps >> stable_diff.In(kStepsTag);
    inputs.iteration >> stable_diff.In(kIterationTag);
    inputs.rand_seed >> stable_diff.In(kRandSeedTag);
    mediapipe::StableDiffusionIterateCalculatorOptions& options =
        stable_diff
            .GetOptions<mediapipe::StableDiffusionIterateCalculatorOptions>();
    options.set_base_seed(0);
    options.set_output_image_height(kPluginsOutputSize);
    options.set_output_image_width(kPluginsOutputSize);
    options.set_file_folder(subgraph_options.text2image_model_directory());
    options.set_show_every_n_iteration(100);
    options.set_emit_empty_packet(true);
    if (lora_resources.has_value()) {
      // Map each LoRA weight file (by basename) to the memory address of its
      // contents; the "metadata" file instead configures calculator options.
      auto& lora_layer_weights_mapping =
          *options.mutable_lora_weights_layer_mapping();
      for (const auto& file_path : (*lora_resources)->ListFiles()) {
        auto basename = file::Basename(file_path);
        ASSIGN_OR_RETURN(auto file_content,
                         (*lora_resources)->GetFile(std::string(file_path)));
        if (file_path == kMetadataFilename) {
          MP_RETURN_IF_ERROR(
              ParseLoraMetadataAndConfigOptions(file_content, options));
        } else {
          // NOTE(review): stores a raw pointer into the asset bundle as an
          // integer; assumes the bundle outlives the running graph — confirm.
          lora_layer_weights_mapping[basename] =
              reinterpret_cast<uint64_t>(file_content.data());
        }
      }
    }
    auto& to_image = graph.AddNode("ToImageCalculator");
    stable_diff.Out(kImageTag) >> to_image.In(kImageCpuTag);
    return {{to_image.Out(kImageTag).Cast<Image>()}};
  }

 private:
  // Parses the LoRA "metadata" file: one "key,value" pair per non-empty line.
  // Currently only the "lora_rank" key is recognized; malformed lora_rank
  // lines produce InvalidArgumentError, unknown keys are ignored.
  absl::Status ParseLoraMetadataAndConfigOptions(
      absl::string_view contents,
      mediapipe::StableDiffusionIterateCalculatorOptions& options) {
    std::vector<absl::string_view> lines =
        absl::StrSplit(contents, '\n', absl::SkipEmpty());
    for (const auto& line : lines) {
      std::vector<absl::string_view> values = absl::StrSplit(line, ',');
      if (values[0] == kLoraRankStr) {
        int lora_rank;
        if (values.size() != 2 || !absl::SimpleAtoi(values[1], &lora_rank)) {
          return absl::InvalidArgumentError(
              absl::StrCat("Error parsing LoRA weights metadata. ", line));
        }
        options.set_lora_rank(lora_rank);
      }
    }
    return absl::OkStatus();
  }
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::image_generator::ImageGeneratorGraph);
} // namespace image_generator
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,41 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_
#include "mediapipe/framework/formats/image.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_generator {
// The result of ImageGenerator task.
struct ImageGeneratorResult {
  // The generated image.
  Image generated_image;

  // The condition_image used in the plugin model, only available if the
  // condition type is set in ImageGeneratorOptions. This is the derived
  // condition image fed to the plugin model, not the caller's source image.
  std::optional<Image> condition_image = std::nullopt;
};
} // namespace image_generator
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_

View File

@ -0,0 +1,52 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
mediapipe_proto_library(
name = "conditioned_image_graph_options_proto",
srcs = ["conditioned_image_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_proto",
],
)
mediapipe_proto_library(
name = "control_plugin_graph_options_proto",
srcs = ["control_plugin_graph_options.proto"],
deps = [
":conditioned_image_graph_options_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)
mediapipe_proto_library(
name = "image_generator_graph_options_proto",
srcs = ["image_generator_graph_options.proto"],
deps = [
":control_plugin_graph_options_proto",
"//mediapipe/tasks/cc/core/proto:external_file_proto",
],
)

View File

@ -0,0 +1,66 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.vision.image_generator.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto";
import "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
option java_outer_classname = "ConditionedImageGraphOptionsProto";
message ConditionedImageGraphOptions {
  // For conditioned image graph based on face landmarks.
  message FaceConditionTypeOptions {
    // Options for the face landmarker used in the face landmarks type graph.
    face_landmarker.proto.FaceLandmarkerGraphOptions
        face_landmarker_graph_options = 1;
  }

  // For conditioned image graph based on edge detection.
  message EdgeConditionTypeOptions {
    // These parameters are used to configure the Canny edge algorithm of
    // OpenCV. See more details:
    // https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de

    // First threshold for the hysteresis procedure.
    float threshold_1 = 1;

    // Second threshold for the hysteresis procedure.
    float threshold_2 = 2;

    // Aperture size for the Sobel operator. Typical range is 3~7.
    int32 aperture_size = 3;

    // A flag, indicating whether a more accurate L2 norm should be used to
    // calculate the image gradient magnitude ( L2gradient=true ), or whether
    // the default L1 norm is enough ( L2gradient=false ).
    bool l2_gradient = 4;
  }

  // For conditioned image graph based on a depth map.
  message DepthConditionTypeOptions {
    // Options for the image segmenter used in the depth condition type graph.
    image_segmenter.proto.ImageSegmenterGraphOptions
        image_segmenter_graph_options = 1;
  }

  // The options for configuring the conditioned image graph. Exactly one
  // condition type is selected per graph instance.
  oneof condition_type_options {
    FaceConditionTypeOptions face_condition_type_options = 2;
    EdgeConditionTypeOptions edge_condition_type_options = 3;
    DepthConditionTypeOptions depth_condition_type_options = 4;
  }
}

View File

@ -0,0 +1,34 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.vision.image_generator.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
option java_outer_classname = "ControlPluginGraphOptionsProto";
message ControlPluginGraphOptions {
  // The base options for the control plugin model.
  core.proto.BaseOptions base_options = 1;

  // Options for the ConditionedImageGraph used to generate the control plugin
  // model's input image.
  proto.ConditionedImageGraphOptions conditioned_image_graph_options = 2;
}

View File

@ -0,0 +1,35 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.vision.image_generator.proto;
import "mediapipe/tasks/cc/core/proto/external_file.proto";
import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
option java_outer_classname = "ImageGeneratorGraphOptionsProto";
message ImageGeneratorGraphOptions {
  // The directory containing the model weights of the text-to-image model.
  string text2image_model_directory = 1;

  // An optional LoRA weights file. If set, the diffusion model will be created
  // with LoRA weights.
  core.proto.ExternalFile lora_weights_file = 2;

  // Options for the control plugin graphs, one entry per supported condition
  // type.
  repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3;
}

View File

@ -59,6 +59,22 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite",
] ]
_VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_geometry/proto:mesh_3d_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
]
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
@ -249,6 +265,39 @@ EOF
native_library = native_library, native_library = native_library,
) )
def mediapipe_tasks_vision_image_generator_aar(name, srcs, native_library):
    """Builds mediapipe tasks vision image generator AAR.

    Args:
      name: The bazel target name.
      srcs: MediaPipe Vision Tasks' source files.
      native_library: The native library that contains image generator task's graph and calculators.
    """

    # Generates the AndroidManifest.xml for the AAR with the image generator
    # package name and the supported SDK range.
    native.genrule(
        name = name + "tasks_manifest_generator",
        outs = ["AndroidManifest.xml"],
        cmd = """
cat > $(OUTS) <<EOF
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.vision.imagegenerator">
  <uses-sdk
      android:minSdkVersion="24"
      android:targetSdkVersion="30" />
</manifest>
EOF
""",
    )

    # Delegates to the shared AAR builder with the core + image-generator proto deps.
    _mediapipe_tasks_aar(
        name = name,
        srcs = srcs,
        manifest = "AndroidManifest.xml",
        java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS,
        native_library = native_library,
    )
def mediapipe_tasks_text_aar(name, srcs, native_library): def mediapipe_tasks_text_aar(name, srcs, native_library):
"""Builds medaipipe tasks text AAR. """Builds medaipipe tasks text AAR.
@ -344,6 +393,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l
"//third_party:androidx_annotation", "//third_party:androidx_annotation",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
"@com_google_protobuf//:protobuf_javalite",
] + select({ ] + select({
"//conditions:default": [":" + name + "_jni_opencv_cc_lib"], "//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
"//mediapipe/framework/port:disable_opencv": [], "//mediapipe/framework/port:disable_opencv": [],

View File

@ -413,6 +413,9 @@ load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl"
mediapipe_tasks_vision_aar( mediapipe_tasks_vision_aar(
name = "tasks_vision", name = "tasks_vision",
srcs = glob(["**/*.java"]), srcs = glob(
["**/*.java"],
exclude = ["imagegenerator/**"],
),
native_library = ":libmediapipe_tasks_vision_jni_lib", native_library = ":libmediapipe_tasks_vision_jni_lib",
) )

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.imagegenerator">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@ -0,0 +1,84 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
# The native library of MediaPipe vision image generator tasks.
cc_binary(
    name = "libmediapipe_tasks_vision_image_generator_jni.so",
    linkopts = [
        # Fail at link time on unresolved symbols rather than at dlopen time.
        "-Wl,--no-undefined",
        # Export only the symbols listed in the shared version script.
        "-Wl,--version-script,$(location //mediapipe/tasks/java:version_script.lds)",
    ],
    linkshared = 1,
    linkstatic = 1,
    deps = [
        "//mediapipe/calculators/core:flow_limiter_calculator",
        "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
        "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
        "//mediapipe/tasks/cc/vision/image_generator:image_generator_graph",
        "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
        "//mediapipe/tasks/java:version_script.lds",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
    ],
)

# Wraps the prebuilt .so so Java/AAR targets can depend on it; alwayslink keeps
# the JNI entry points even though no code references them at link time.
cc_library(
    name = "libmediapipe_tasks_vision_image_generator_jni_lib",
    srcs = [":libmediapipe_tasks_vision_image_generator_jni.so"],
    alwayslink = 1,
)
# Java source library for the image generator task; packaged into the AAR target below.
android_library(
    name = "imagegenerator",
    srcs = [
        "ImageGenerator.java",
        "ImageGeneratorResult.java",
    ],
    javacopts = [
        # Disables the Error Prone Android JDK libs check; the sources use
        # java.util.Optional and related APIs.
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    manifest = "AndroidManifest.xml",
    deps = [
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/java/com/google/mediapipe/framework/image",
        "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:core_java",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:facelandmarker",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:imagesegmenter",
        "//third_party:any_java_proto",
        "//third_party:autovalue",
        "//third_party/java/protobuf:protobuf_lite",
        "@maven//:androidx_annotation_annotation",
        "@maven//:com_google_guava_guava",
    ],
)
# NOTE(review): Buildifier convention places load() at the top of the BUILD file;
# this view may reflect diff ordering — confirm placement in the actual file.
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_image_generator_aar")

# Bundles the Java sources and the JNI library above into the image generator AAR.
mediapipe_tasks_vision_image_generator_aar(
    name = "tasks_vision_image_generator",
    srcs = glob(["**/*.java"]),
    native_library = ":libmediapipe_tasks_vision_image_generator_jni_lib",
)

View File

@ -0,0 +1,660 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imagegenerator;
import android.content.Context;
import android.graphics.Bitmap;
import android.util.Log;
import androidx.annotation.Nullable;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.ExternalFileProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.facelandmarker.FaceLandmarker.FaceLandmarkerOptions;
import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarkerGraphOptionsProto.FaceLandmarkerGraphOptions;
import com.google.mediapipe.tasks.vision.imagegenerator.proto.ConditionedImageGraphOptionsProto.ConditionedImageGraphOptions;
import com.google.mediapipe.tasks.vision.imagegenerator.proto.ControlPluginGraphOptionsProto;
import com.google.mediapipe.tasks.vision.imagegenerator.proto.ImageGeneratorGraphOptionsProto;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions;
import com.google.protobuf.Any;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
/** Performs image generation from a text prompt. */
public final class ImageGenerator extends BaseVisionTaskApi {

  // Input stream names of the ImageGeneratorGraph.
  private static final String STEPS_STREAM_NAME = "steps";
  private static final String ITERATION_STREAM_NAME = "iteration";
  private static final String PROMPT_STREAM_NAME = "prompt";
  private static final String RAND_SEED_STREAM_NAME = "rand_seed";
  // Streams used by the ConditionedImageGraphContainer (plugin-model path).
  private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image";
  private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image";
  private static final String SELECT_STREAM_NAME = "select";
  // Indices into the output packet list; must stay in sync with the
  // outputStreams order set in createFromOptions ("IMAGE", "STEPS", "ITERATION").
  private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0;
  private static final int STEPS_OUT_STREAM_INDEX = 1;
  private static final int ITERATION_OUT_STREAM_INDEX = 2;
  private static final String TASK_GRAPH_NAME =
      "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
  private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME =
      "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
  private static final String TAG = "ImageGenerator";

  // Secondary runner that converts a source image into a condition image;
  // only assigned when the generator is created with ConditionOptions.
  private TaskRunner conditionImageGraphsContainerTaskRunner;
  // Maps each configured condition type to its plugin index in the graphs,
  // assigned in FACE, EDGE, DEPTH order of presence.
  private Map<ConditionOptions.ConditionType, Integer> conditionTypeIndex;
  // True when the generator was created with ConditionOptions.
  private boolean useConditionImage = false;
  /**
   * Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}.
   *
   * @param context an Android {@link Context}.
   * @param generatorOptions an {@link ImageGeneratorOptions} instance.
   * @return a new {@link ImageGenerator} without condition-image (plugin model) support.
   * @throws MediaPipeException if there is an error during {@link ImageGenerator} creation.
   */
  public static ImageGenerator createFromOptions(
      Context context, ImageGeneratorOptions generatorOptions) {
    // Delegates to the full factory with no ConditionOptions (text-to-image only).
    return createFromOptions(context, generatorOptions, null);
  }
  /**
   * Creates an {@link ImageGenerator} instance, from {@link ImageGeneratorOptions} and {@link
   * ConditionOptions}, if plugin models are used to generate an image based on the condition image.
   *
   * @param context an Android {@link Context}.
   * @param generatorOptions an {@link ImageGeneratorOptions} instance.
   * @param conditionOptions an {@link ConditionOptions} instance; may be null for plain
   *     text-to-image generation.
   * @throws MediaPipeException if there is an error during {@link ImageGenerator} creation.
   */
  public static ImageGenerator createFromOptions(
      Context context,
      ImageGeneratorOptions generatorOptions,
      @Nullable ConditionOptions conditionOptions) {
    List<String> inputStreams = new ArrayList<>();
    inputStreams.addAll(
        Arrays.asList(
            "STEPS:" + STEPS_STREAM_NAME,
            "ITERATION:" + ITERATION_STREAM_NAME,
            "PROMPT:" + PROMPT_STREAM_NAME,
            "RAND_SEED:" + RAND_SEED_STREAM_NAME));
    final boolean useConditionImage = conditionOptions != null;
    if (useConditionImage) {
      // Condition-image mode adds the plugin-selection index and the condition image inputs.
      inputStreams.add("SELECT:" + SELECT_STREAM_NAME);
      inputStreams.add("CONDITION_IMAGE:" + CONDITION_IMAGE_STREAM_NAME);
      // Stash the condition options on the generator options so that
      // ImageGeneratorOptions.convertToAnyProto() can merge them into the graph options.
      generatorOptions.conditionOptions = Optional.of(conditionOptions);
    }
    List<String> outputStreams =
        Arrays.asList("IMAGE:image_out", "STEPS:steps_out", "ITERATION:iteration_out");
    OutputHandler<ImageGeneratorResult, Void> handler = new OutputHandler<>();
    handler.setOutputPacketConverter(
        new OutputHandler.OutputPacketConverter<ImageGeneratorResult, Void>() {
          @Override
          @Nullable
          public ImageGeneratorResult convertToTaskResult(List<Packet> packets) {
            int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX));
            int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX));
            Log.i("ImageGenerator", "Iteration: " + iteration + ", Steps: " + steps);
            // Intermediate iterations produce no image; only the final one yields a result.
            if (iteration != steps - 1) {
              return null;
            }
            Log.i("ImageGenerator", "processing generated image");
            Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX);
            Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet);
            BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap);
            return ImageGeneratorResult.create(
                bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
          }

          @Override
          public Void convertToTaskInput(List<Packet> packets) {
            return null;
          }
        });
    handler.setHandleTimestampBoundChanges(true);
    if (generatorOptions.resultListener().isPresent()) {
      // Adapts the user's PureResultListener to the OutputHandler's ResultListener interface.
      ResultListener<ImageGeneratorResult, Void> resultListener =
          new ResultListener<ImageGeneratorResult, Void>() {
            @Override
            public void run(ImageGeneratorResult imageGeneratorResult, Void input) {
              generatorOptions.resultListener().get().run(imageGeneratorResult);
            }
          };
      handler.setResultListener(resultListener);
    }
    generatorOptions.errorListener().ifPresent(handler::setErrorListener);
    TaskRunner runner =
        TaskRunner.create(
            context,
            TaskInfo.<ImageGeneratorOptions>builder()
                .setTaskName(ImageGenerator.class.getSimpleName())
                .setTaskRunningModeName(RunningMode.IMAGE.name())
                .setTaskGraphName(TASK_GRAPH_NAME)
                .setInputStreams(inputStreams)
                .setOutputStreams(outputStreams)
                .setTaskOptions(generatorOptions)
                .setEnableFlowLimiting(false)
                .build(),
            handler);
    ImageGenerator imageGenerator = new ImageGenerator(runner);
    if (useConditionImage) {
      imageGenerator.useConditionImage = true;
      // Second runner: the container graph that turns a source image into a condition
      // image (face landmarks / edge / depth) ahead of generation.
      inputStreams =
          Arrays.asList(
              "IMAGE:" + SOURCE_CONDITION_IMAGE_STREAM_NAME, "SELECT:" + SELECT_STREAM_NAME);
      outputStreams = Arrays.asList("CONDITION_IMAGE:" + CONDITION_IMAGE_STREAM_NAME);
      OutputHandler<ConditionImageResult, Void> conditionImageHandler = new OutputHandler<>();
      conditionImageHandler.setOutputPacketConverter(
          new OutputHandler.OutputPacketConverter<ConditionImageResult, Void>() {
            @Override
            public ConditionImageResult convertToTaskResult(List<Packet> packets) {
              Packet packet = packets.get(0);
              return new AutoValue_ImageGenerator_ConditionImageResult(
                  new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(packet)).build(),
                  packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
            }

            @Override
            public Void convertToTaskInput(List<Packet> packets) {
              return null;
            }
          });
      conditionImageHandler.setHandleTimestampBoundChanges(true);
      imageGenerator.conditionImageGraphsContainerTaskRunner =
          TaskRunner.create(
              context,
              TaskInfo.<ImageGeneratorOptions>builder()
                  .setTaskName(ImageGenerator.class.getSimpleName())
                  .setTaskRunningModeName(RunningMode.IMAGE.name())
                  .setTaskGraphName(CONDITION_IMAGE_GRAPHS_CONTAINER_NAME)
                  .setInputStreams(inputStreams)
                  .setOutputStreams(outputStreams)
                  .setTaskOptions(generatorOptions)
                  .setEnableFlowLimiting(false)
                  .build(),
              conditionImageHandler);
      // Index each configured condition type in FACE, EDGE, DEPTH order of presence;
      // this must match the order the plugin graphs are added in
      // ConditionOptions.convertToAnyProto().
      imageGenerator.conditionTypeIndex = new HashMap<>();
      if (conditionOptions.faceConditionOptions().isPresent()) {
        imageGenerator.conditionTypeIndex.put(
            ConditionOptions.ConditionType.FACE, imageGenerator.conditionTypeIndex.size());
      }
      if (conditionOptions.edgeConditionOptions().isPresent()) {
        imageGenerator.conditionTypeIndex.put(
            ConditionOptions.ConditionType.EDGE, imageGenerator.conditionTypeIndex.size());
      }
      if (conditionOptions.depthConditionOptions().isPresent()) {
        imageGenerator.conditionTypeIndex.put(
            ConditionOptions.ConditionType.DEPTH, imageGenerator.conditionTypeIndex.size());
      }
    }
    return imageGenerator;
  }
  /** Constructs an {@link ImageGenerator} wrapping the main generation {@link TaskRunner}. */
  private ImageGenerator(TaskRunner taskRunner) {
    // IMAGE running mode; the empty strings are the (unused) image/norm-rect stream names
    // expected by BaseVisionTaskApi — TODO confirm against BaseVisionTaskApi's constructor.
    super(taskRunner, RunningMode.IMAGE, "", "");
  }
  /**
   * Generates an image for iterations and the given random seed. Only valid when the ImageGenerator
   * is created without condition options.
   *
   * @param prompt The text prompt describing the image to be generated.
   * @param iterations The total iterations to generate the image.
   * @param seed The random seed used during image generation.
   * @return the {@link ImageGeneratorResult} of the final iteration.
   */
  public ImageGeneratorResult generate(String prompt, int iterations, int seed) {
    // No condition image; plugin selection index 0 is unused in this mode.
    return runIterations(prompt, iterations, seed, null, 0);
  }
  /**
   * Generates an image based on the source image for iterations and the given random seed. Only
   * valid when the ImageGenerator is created with condition options.
   *
   * @param prompt The text prompt describing the image to be generated.
   * @param sourceConditionImage The source image used to create the condition image, which is used
   *     as a guidance for the image generation.
   * @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of
   *     condition image.
   * @param iterations The total iterations to generate the image.
   * @param seed The random seed used during image generation.
   * @return the {@link ImageGeneratorResult} of the final iteration, including the condition image.
   */
  public ImageGeneratorResult generate(
      String prompt,
      MPImage sourceConditionImage,
      ConditionOptions.ConditionType conditionType,
      int iterations,
      int seed) {
    // Derive the condition image from the source first, then run the generation
    // iterations with the plugin index registered for this condition type.
    return runIterations(
        prompt,
        iterations,
        seed,
        createConditionImage(sourceConditionImage, conditionType),
        conditionTypeIndex.get(conditionType));
  }
  /**
   * Create the condition image of specified condition type from the source image. Currently support
   * face landmarks, depth image and edge image as the condition image.
   *
   * @param sourceConditionImage The source image used to create the condition image.
   * @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of
   *     condition image.
   * @return the condition image produced by the condition-image container graph.
   * @throws IllegalArgumentException if the condition type was not configured at creation time.
   */
  public MPImage createConditionImage(
      MPImage sourceConditionImage, ConditionOptions.ConditionType conditionType) {
    if (!conditionTypeIndex.containsKey(conditionType)) {
      throw new IllegalArgumentException(
          "The condition type " + conditionType.name() + " is not created during initialization.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(
        SOURCE_CONDITION_IMAGE_STREAM_NAME,
        conditionImageGraphsContainerTaskRunner
            .getPacketCreator()
            .createImage(sourceConditionImage));
    // SELECT routes the image to the graph registered for this condition type.
    inputPackets.put(
        SELECT_STREAM_NAME,
        conditionImageGraphsContainerTaskRunner
            .getPacketCreator()
            .createInt32(conditionTypeIndex.get(conditionType)));
    ConditionImageResult result =
        (ConditionImageResult) conditionImageGraphsContainerTaskRunner.process(inputPackets);
    return result.conditionImage();
  }
  /**
   * Runs the generation graph for {@code steps} iterations and returns the final result.
   *
   * @param prompt text prompt sent on every iteration.
   * @param steps total number of iterations.
   * @param seed random seed sent on every iteration.
   * @param conditionImage optional condition image, sent only on the first iteration.
   * @param select plugin index for the condition image; ignored when conditionImage is null.
   */
  private ImageGeneratorResult runIterations(
      String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) {
    ImageGeneratorResult result = null;
    // Each process() call needs a strictly increasing timestamp; start from wall-clock
    // time in microseconds and bump by one per iteration.
    long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
    for (int i = 0; i < steps; i++) {
      Map<String, Packet> inputPackets = new HashMap<>();
      if (i == 0 && useConditionImage) {
        // The condition image and plugin selection are only fed on the first iteration.
        inputPackets.put(
            CONDITION_IMAGE_STREAM_NAME, runner.getPacketCreator().createImage(conditionImage));
        inputPackets.put(SELECT_STREAM_NAME, runner.getPacketCreator().createInt32(select));
      }
      inputPackets.put(PROMPT_STREAM_NAME, runner.getPacketCreator().createString(prompt));
      inputPackets.put(STEPS_STREAM_NAME, runner.getPacketCreator().createInt32(steps));
      inputPackets.put(ITERATION_STREAM_NAME, runner.getPacketCreator().createInt32(i));
      inputPackets.put(RAND_SEED_STREAM_NAME, runner.getPacketCreator().createInt32(seed));
      // The output converter returns null for all but the last iteration, so only the
      // final call yields a non-null ImageGeneratorResult.
      result = (ImageGeneratorResult) runner.process(inputPackets, timestamp++);
    }
    if (useConditionImage) {
      // Add condition image to the ImageGeneratorResult.
      return ImageGeneratorResult.create(
          result.generatedImage(), conditionImage, result.timestampMs());
    }
    return result;
  }
/** Closes and cleans up the task runners. */
@Override
public void close() {
runner.close();
conditionImageGraphsContainerTaskRunner.close();
}
  /** A container class for the condition image. */
  @AutoValue
  protected abstract static class ConditionImageResult implements TaskResult {

    /** The condition image produced by the condition-image container graph. */
    public abstract MPImage conditionImage();

    @Override
    public abstract long timestampMs();
  }
  /** Options for setting up an {@link ImageGenerator}. */
  @AutoValue
  public abstract static class ImageGeneratorOptions extends TaskOptions {

    /** Builder for {@link ImageGeneratorOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {
      /** Sets the text to image model directory storing the model weights. */
      public abstract Builder setText2ImageModelDirectory(String modelDirectory);

      /** Sets the path to LoRA weights file. */
      public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath);

      /** Sets the {@link PureResultListener} to receive {@link ImageGeneratorResult}s. */
      public abstract Builder setResultListener(
          PureResultListener<ImageGeneratorResult> resultListener);

      /** Sets an optional {@link ErrorListener}. */
      public abstract Builder setErrorListener(ErrorListener value);

      abstract ImageGeneratorOptions autoBuild();

      /** Validates and builds the {@link ImageGeneratorOptions} instance. */
      public final ImageGeneratorOptions build() {
        return autoBuild();
      }
    }
    abstract String text2ImageModelDirectory();

    abstract Optional<String> loraWeightsFilePath();

    abstract Optional<PureResultListener<ImageGeneratorResult>> resultListener();

    abstract Optional<ErrorListener> errorListener();

    // NOTE(review): mutable field on an otherwise-immutable AutoValue class; assigned by
    // createFromOptions() before the options are serialized in convertToAnyProto().
    // Consider moving this into the builder.
    private Optional<ConditionOptions> conditionOptions;

    public static Builder builder() {
      return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder()
          .setText2ImageModelDirectory("");
    }
/** Converts an {@link ImageGeneratorOptions} to a {@link Any} protobuf message. */
@Override
public Any convertToAnyProto() {
ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.Builder taskOptionsBuilder =
ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.newBuilder();
if (conditionOptions != null && conditionOptions.isPresent()) {
try {
taskOptionsBuilder.mergeFrom(
conditionOptions.get().convertToAnyProto().getValue(),
ExtensionRegistryLite.getGeneratedRegistry());
} catch (InvalidProtocolBufferException e) {
Log.e(TAG, "Error converting ConditionOptions to proto. " + e.getMessage());
e.printStackTrace();
}
}
taskOptionsBuilder.setText2ImageModelDirectory(text2ImageModelDirectory());
if (loraWeightsFilePath().isPresent()) {
ExternalFileProto.ExternalFile.Builder externalFileBuilder =
ExternalFileProto.ExternalFile.newBuilder();
externalFileBuilder.setFileName(loraWeightsFilePath().get());
taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build());
}
return Any.newBuilder()
.setTypeUrl(
"type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
.setValue(taskOptionsBuilder.build().toByteString())
.build();
}
}
  /** Options for setting up the conditions types and the plugin models. */
  @AutoValue
  public abstract static class ConditionOptions extends TaskOptions {

    /** The supported condition type. */
    public enum ConditionType {
      FACE,
      EDGE,
      DEPTH
    }

    /** Builder for {@link ConditionOptions}. At least one type of condition options must be set. */
    @AutoValue.Builder
    public abstract static class Builder {
      /** Sets the options for the FACE condition type (face landmarks plugin). */
      public abstract Builder setFaceConditionOptions(FaceConditionOptions faceConditionOptions);

      /** Sets the options for the DEPTH condition type (depth plugin). */
      public abstract Builder setDepthConditionOptions(DepthConditionOptions depthConditionOptions);

      /** Sets the options for the EDGE condition type (Canny edge plugin). */
      public abstract Builder setEdgeConditionOptions(EdgeConditionOptions edgeConditionOptions);

      abstract ConditionOptions autoBuild();

      /** Validates and builds the {@link ConditionOptions} instance. */
      public final ConditionOptions build() {
        ConditionOptions options = autoBuild();
        // At least one condition type is required; otherwise there is no plugin to run.
        if (!options.faceConditionOptions().isPresent()
            && !options.depthConditionOptions().isPresent()
            && !options.edgeConditionOptions().isPresent()) {
          throw new IllegalArgumentException(
              "At least one of `faceConditionOptions`, `depthConditionOptions` and"
                  + " `edgeConditionOptions` must be set.");
        }
        return options;
      }
    }

    abstract Optional<FaceConditionOptions> faceConditionOptions();

    abstract Optional<DepthConditionOptions> depthConditionOptions();

    abstract Optional<EdgeConditionOptions> edgeConditionOptions();

    public static Builder builder() {
      return new AutoValue_ImageGenerator_ConditionOptions.Builder();
    }
/**
* Converts an {@link ImageGeneratorOptions} to a {@link CalculatorOptions} protobuf message.
*/
@Override
public Any convertToAnyProto() {
ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.Builder taskOptionsBuilder =
ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.newBuilder();
if (faceConditionOptions().isPresent()) {
taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions(
convertBaseOptionsToProto(faceConditionOptions().get().baseOptions()))
.setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder()
.setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto())
.build())
.build());
}
if (edgeConditionOptions().isPresent()) {
taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions(
convertBaseOptionsToProto(edgeConditionOptions().get().baseOptions()))
.setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder()
.setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto())
.build())
.build());
if (depthConditionOptions().isPresent()) {
taskOptionsBuilder.addControlPluginGraphsOptions(
ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
.setBaseOptions(
convertBaseOptionsToProto(depthConditionOptions().get().baseOptions()))
.setConditionedImageGraphOptions(
ConditionedImageGraphOptions.newBuilder()
.setDepthConditionTypeOptions(
depthConditionOptions().get().convertToProto())
.build())
.build());
}
}
return Any.newBuilder()
.setTypeUrl(
"type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
.setValue(taskOptionsBuilder.build().toByteString())
.build();
}
    /** Options for drawing face landmarks image. */
    @AutoValue
    public abstract static class FaceConditionOptions extends TaskOptions {

      /** Builder for {@link FaceConditionOptions}. */
      @AutoValue.Builder
      public abstract static class Builder {
        /** Set the base options for plugin model. */
        public abstract Builder setBaseOptions(BaseOptions baseOptions);

        /** {@link FaceLandmarkerOptions} used to detect face landmarks in the source image. */
        public abstract Builder setFaceLandmarkerOptions(
            FaceLandmarkerOptions faceLandmarkerOptions);

        abstract FaceConditionOptions autoBuild();

        /** Validates and builds the {@link FaceConditionOptions} instance. */
        public final FaceConditionOptions build() {
          return autoBuild();
        }
      }

      abstract BaseOptions baseOptions();

      abstract FaceLandmarkerOptions faceLandmarkerOptions();

      public static Builder builder() {
        return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder();
      }

      /** Converts the face landmarker options into the FaceConditionTypeOptions proto. */
      ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() {
        return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder()
            .setFaceLandmarkerGraphOptions(
                FaceLandmarkerGraphOptions.newBuilder()
                    // Copies the graph options out of the calculator-options extension.
                    .mergeFrom(
                        faceLandmarkerOptions()
                            .convertToCalculatorOptionsProto()
                            .getExtension(FaceLandmarkerGraphOptions.ext))
                    .build())
            .build();
      }
    }
    /** Options for detecting depth image. */
    @AutoValue
    public abstract static class DepthConditionOptions extends TaskOptions {

      /** Builder for {@link DepthConditionOptions}. */
      @AutoValue.Builder
      public abstract static class Builder {
        /** Set the base options for plugin model. */
        public abstract Builder setBaseOptions(BaseOptions baseOptions);

        /** {@link ImageSegmenterOptions} used to detect depth image from the source image. */
        public abstract Builder setImageSegmenterOptions(
            ImageSegmenterOptions imageSegmenterOptions);

        abstract DepthConditionOptions autoBuild();

        /** Validates and builds the {@link DepthConditionOptions} instance. */
        public final DepthConditionOptions build() {
          DepthConditionOptions options = autoBuild();
          return options;
        }
      }

      abstract BaseOptions baseOptions();

      abstract ImageSegmenterOptions imageSegmenterOptions();

      public static Builder builder() {
        return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder();
      }

      /** Converts the image segmenter options into the DepthConditionTypeOptions proto. */
      ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() {
        return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder()
            .setImageSegmenterGraphOptions(
                imageSegmenterOptions()
                    .convertToCalculatorOptionsProto()
                    .getExtension(ImageSegmenterGraphOptions.ext))
            .build();
      }
    }
    /** Options for detecting edge image. */
    @AutoValue
    public abstract static class EdgeConditionOptions {

      /**
       * Builder for {@link EdgeConditionOptions}.
       *
       * <p>These parameters are used to config Canny edge algorithm of OpenCV.
       *
       * <p>See more details:
       * https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de
       */
      @AutoValue.Builder
      public abstract static class Builder {
        /** Set the base options for plugin model. */
        public abstract Builder setBaseOptions(BaseOptions baseOptions);

        /** First threshold for the hysteresis procedure. */
        public abstract Builder setThreshold1(Float threshold1);

        /** Second threshold for the hysteresis procedure. */
        public abstract Builder setThreshold2(Float threshold2);

        /** Aperture size for the Sobel operator. Typical range is 3~7. */
        public abstract Builder setApertureSize(Integer apertureSize);

        /**
         * Flag indicating whether a more accurate L2 norm should be used to calculate the image
         * gradient magnitude (L2gradient=true), or whether the default L1 norm is enough
         * (L2gradient=false).
         */
        public abstract Builder setL2Gradient(Boolean l2Gradient);

        abstract EdgeConditionOptions autoBuild();

        /** Validates and builds the {@link EdgeConditionOptions} instance. */
        public final EdgeConditionOptions build() {
          return autoBuild();
        }
      }

      abstract BaseOptions baseOptions();

      abstract Float threshold1();

      abstract Float threshold2();

      abstract Integer apertureSize();

      abstract Boolean l2Gradient();

      /** Returns a builder pre-populated with the default Canny parameters. */
      public static Builder builder() {
        return new AutoValue_ImageGenerator_ConditionOptions_EdgeConditionOptions.Builder()
            .setThreshold1(100f)
            .setThreshold2(200f)
            .setApertureSize(3)
            .setL2Gradient(false);
      }

      // NOTE(review): unlike FaceConditionOptions and DepthConditionOptions this class does
      // not extend TaskOptions — confirm that is intentional.
      ConditionedImageGraphOptions.EdgeConditionTypeOptions convertToProto() {
        return ConditionedImageGraphOptions.EdgeConditionTypeOptions.newBuilder()
            .setThreshold1(threshold1())
            .setThreshold2(threshold2())
            .setApertureSize(apertureSize())
            .setL2Gradient(l2Gradient())
            .build();
      }
    }
}
}

View File

@ -0,0 +1,44 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imagegenerator;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.Optional;
/** Represents the image generation results generated by {@link ImageGenerator}. */
@AutoValue
public abstract class ImageGeneratorResult implements TaskResult {

  /**
   * Create an {@link ImageGeneratorResult} instance from the generated image and the condition
   * image used to guide generation.
   */
  public static ImageGeneratorResult create(
      MPImage generatedImage, MPImage conditionImage, long timestampMs) {
    return new AutoValue_ImageGeneratorResult(
        generatedImage, Optional.of(conditionImage), timestampMs);
  }

  /** Create an {@link ImageGeneratorResult} instance from the generated image. */
  public static ImageGeneratorResult create(MPImage generatedImage, long timestampMs) {
    return new AutoValue_ImageGeneratorResult(generatedImage, Optional.empty(), timestampMs);
  }

  /** The generated image. */
  public abstract MPImage generatedImage();

  /** The condition image; present only when the result was created with one. */
  public abstract Optional<MPImage> conditionImage();

  @Override
  public abstract long timestampMs();
}