Internal update
PiperOrigin-RevId: 561148365
This commit is contained in: parent 01fbbd9f67, commit e18e749e3e
mediapipe/tasks/cc/vision/image_generator/BUILD (new file, 136 lines)
@@ -0,0 +1,136 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])

package(default_visibility = ["//mediapipe/tasks:internal"])

cc_library(
    name = "conditioned_image_graph",
    srcs = ["conditioned_image_graph.cc"],
    deps = [
        "//mediapipe/calculators/core:get_vector_item_calculator",
        "//mediapipe/calculators/core:get_vector_item_calculator_cc_proto",
        "//mediapipe/calculators/util:annotation_overlay_calculator",
        "//mediapipe/calculators/util:flat_color_image_calculator",
        "//mediapipe/calculators/util:flat_color_image_calculator_cc_proto",
        "//mediapipe/calculators/util:landmarks_to_render_data_calculator",
        "//mediapipe/calculators/util:landmarks_to_render_data_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:image_format_cc_proto",
        "//mediapipe/framework/formats:image_frame_opencv",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/port:opencv_core",
        "//mediapipe/framework/port:opencv_imgcodecs",
        "//mediapipe/framework/port:opencv_imgproc",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
        "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarks_connections",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
        "//mediapipe/util:color_cc_proto",
        "//mediapipe/util:image_frame_util",
        "//mediapipe/util:render_data_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_library(
    name = "image_generator_graph",
    srcs = ["image_generator_graph.cc"],
    deps = [
        ":conditioned_image_graph",
        "//mediapipe/calculators/core:pass_through_calculator",
        "//mediapipe/calculators/image:image_transformation_calculator",
        "//mediapipe/calculators/image:image_transformation_calculator_cc_proto",
        "//mediapipe/calculators/tensor:image_to_tensor_calculator",
        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/calculators/tensor:inference_calculator_cc_proto",
        "//mediapipe/calculators/util:from_image_calculator",
        "//mediapipe/calculators/util:to_image_calculator",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:stream_handler_cc_proto",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:switch_container",
        "//mediapipe/framework/tool:switch_container_cc_proto",
        "//mediapipe/tasks/cc/core:model_asset_bundle_resources",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/diffuser:diffusion_plugins_output_calculator",
        "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator",
        "//mediapipe/tasks/cc/vision/image_generator/diffuser:stable_diffusion_iterate_calculator_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_cc_proto",
        "//mediapipe/util:graph_builder_utils",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_library(
    name = "image_generator_result",
    hdrs = ["image_generator_result.h"],
    deps = ["//mediapipe/framework/formats:image"],
)

cc_library(
    name = "image_generator",
    srcs = ["image_generator.cc"],
    hdrs = ["image_generator.h"],
    deps = [
        ":image_generator_graph",
        ":image_generator_result",
        "//mediapipe/framework:packet",
        "//mediapipe/framework:timestamp",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:task_runner",
        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_segmenter",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/time",
    ],
)
mediapipe/tasks/cc/vision/image_generator/conditioned_image_graph.cc (new file, 458 lines)
@@ -0,0 +1,458 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <optional>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/core/get_vector_item_calculator.h"
#include "mediapipe/calculators/core/get_vector_item_calculator.pb.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.h"
#include "mediapipe/calculators/util/landmarks_to_render_data_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/image_frame_util.h"
#include "mediapipe/util/render_data.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_generator {

namespace internal {

// Helper postprocessing calculator for the depth condition type, scaling the
// raw depth inference result to 0-255 uint8.
class DepthImagePostprocessingCalculator : public api2::Node {
 public:
  static constexpr api2::Input<Image> kImageIn{"IMAGE"};
  static constexpr api2::Output<Image> kImageOut{"IMAGE"};

  MEDIAPIPE_NODE_CONTRACT(kImageIn, kImageOut);

  absl::Status Process(CalculatorContext* cc) final {
    if (kImageIn(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    Image raw_depth_image = kImageIn(cc).Get();
    cv::Mat raw_depth_mat = mediapipe::formats::MatView(
        raw_depth_image.GetImageFrameSharedPtr().get());
    cv::Mat depth_mat;
    cv::normalize(raw_depth_mat, depth_mat, 255, 0, cv::NORM_MINMAX);
    depth_mat.convertTo(depth_mat, CV_8UC3, 1, 0);
    cv::cvtColor(depth_mat, depth_mat, cv::COLOR_GRAY2RGB);
    // Acquires the cv::Mat data and assigns it to the image frame.
    ImageFrameSharedPtr depth_image_frame_ptr = std::make_shared<ImageFrame>(
        mediapipe::ImageFormat::SRGB, depth_mat.cols, depth_mat.rows,
        depth_mat.step, depth_mat.data,
        [depth_mat](uint8_t[]) { depth_mat.~Mat(); });
    Image depth_image(depth_image_frame_ptr);
    kImageOut(cc).Send(depth_image);
    return absl::OkStatus();
  }
};

// NOLINTBEGIN: Node registration doesn't work when part of the calculator name
// is moved to the next line.
// clang-format off
MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::vision::image_generator::internal::DepthImagePostprocessingCalculator);
// clang-format on
// NOLINTEND

// Calculator to detect edges in the image with OpenCV Canny edge detection.
class CannyEdgeCalculator : public api2::Node {
 public:
  static constexpr api2::Input<Image> kImageIn{"IMAGE"};
  static constexpr api2::Output<Image> kImageOut{"IMAGE"};

  MEDIAPIPE_NODE_CONTRACT(kImageIn, kImageOut);

  absl::Status Process(CalculatorContext* cc) final {
    if (kImageIn(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    Image input_image = kImageIn(cc).Get();
    cv::Mat input_image_mat =
        mediapipe::formats::MatView(input_image.GetImageFrameSharedPtr().get());
    const auto& options = cc->Options<
        proto::ConditionedImageGraphOptions::EdgeConditionTypeOptions>();
    cv::Mat luminance;
    cv::cvtColor(input_image_mat, luminance, cv::COLOR_RGB2GRAY);
    cv::Mat edges_mat;
    cv::Canny(luminance, edges_mat, options.threshold_1(),
              options.threshold_2(), options.aperture_size(),
              options.l2_gradient());
    cv::normalize(edges_mat, edges_mat, 255, 0, cv::NORM_MINMAX);
    edges_mat.convertTo(edges_mat, CV_8UC3, 1, 0);
    cv::cvtColor(edges_mat, edges_mat, cv::COLOR_GRAY2RGB);
    // Acquires the cv::Mat data and assigns it to the image frame.
    ImageFrameSharedPtr edges_image_frame_ptr = std::make_shared<ImageFrame>(
        mediapipe::ImageFormat::SRGB, edges_mat.cols, edges_mat.rows,
        edges_mat.step, edges_mat.data,
        [edges_mat](uint8_t[]) { edges_mat.~Mat(); });
    Image edges_image(edges_image_frame_ptr);
    kImageOut(cc).Send(edges_image);
    return absl::OkStatus();
  }
};

// NOLINTBEGIN: Node registration doesn't work when part of the calculator name
// is moved to the next line.
// clang-format off
MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::vision::image_generator::internal::CannyEdgeCalculator);
// clang-format on
// NOLINTEND

}  // namespace internal

namespace {

using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;

constexpr absl::string_view kImageTag = "IMAGE";
constexpr absl::string_view kUImageTag = "UIMAGE";
constexpr absl::string_view kNormLandmarksTag = "NORM_LANDMARKS";
constexpr absl::string_view kVectorTag = "VECTOR";
constexpr absl::string_view kItemTag = "ITEM";
constexpr absl::string_view kRenderDataTag = "RENDER_DATA";
constexpr absl::string_view kConfidenceMaskTag = "CONFIDENCE_MASK:0";

enum ColorType {
  WHITE = 0,
  GREEN = 1,
  RED = 2,
  BLACK = 3,
  BLUE = 4,
};

mediapipe::Color GetColor(ColorType color_type) {
  mediapipe::Color color;
  switch (color_type) {
    case WHITE:
      color.set_b(255);
      color.set_g(255);
      color.set_r(255);
      break;
    case GREEN:
      color.set_b(0);
      color.set_g(255);
      color.set_r(0);
      break;
    case RED:
      color.set_b(0);
      color.set_g(0);
      color.set_r(255);
      break;
    case BLACK:
      color.set_b(0);
      color.set_g(0);
      color.set_r(0);
      break;
    case BLUE:
      color.set_b(255);
      color.set_g(0);
      color.set_r(0);
      break;
  }
  return color;
}

// Gets LandmarksToRenderDataCalculatorOptions for rendering face landmarks
// connections.
mediapipe::LandmarksToRenderDataCalculatorOptions
GetFaceLandmarksRenderDataOptions(
    absl::Span<const std::array<int, 2>> connections, ColorType color_type) {
  mediapipe::LandmarksToRenderDataCalculatorOptions render_options;
  render_options.set_thickness(1);
  render_options.set_visualize_landmark_depth(false);
  render_options.set_render_landmarks(false);
  *render_options.mutable_connection_color() = GetColor(color_type);
  for (const auto& connection : connections) {
    render_options.add_landmark_connections(connection[0]);
    render_options.add_landmark_connections(connection[1]);
  }
  return render_options;
}

Source<mediapipe::RenderData> GetFaceLandmarksRenderData(
    Source<mediapipe::NormalizedLandmarkList> face_landmarks,
    const mediapipe::LandmarksToRenderDataCalculatorOptions&
        landmarks_to_render_data_options,
    Graph& graph) {
  auto& landmarks_to_render_data =
      graph.AddNode("LandmarksToRenderDataCalculator");
  landmarks_to_render_data
      .GetOptions<mediapipe::LandmarksToRenderDataCalculatorOptions>()
      .CopyFrom(landmarks_to_render_data_options);
  face_landmarks >> landmarks_to_render_data.In(kNormLandmarksTag);
  return landmarks_to_render_data.Out(kRenderDataTag)
      .Cast<mediapipe::RenderData>();
}

// Adds FaceLandmarkerGraph to detect the face landmarks in the given face
// image, and generates a face mesh guidance image for the diffusion plugin
// model.
absl::StatusOr<Source<Image>> GetFaceLandmarksImage(
    Source<Image> face_image,
    const proto::ConditionedImageGraphOptions::FaceConditionTypeOptions&
        face_condition_type_options,
    Graph& graph) {
  if (face_condition_type_options.face_landmarker_graph_options()
          .face_detector_graph_options()
          .num_faces() != 1) {
    return absl::InvalidArgumentError(
        "Only supports face landmarks of a single face as the guidance image.");
  }

  // Detect face landmarks.
  auto& face_landmarker_graph = graph.AddNode(
      "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph");
  face_landmarker_graph
      .GetOptions<face_landmarker::proto::FaceLandmarkerGraphOptions>()
      .CopyFrom(face_condition_type_options.face_landmarker_graph_options());
  face_image >> face_landmarker_graph.In(kImageTag);
  auto face_landmarks_lists =
      face_landmarker_graph.Out(kNormLandmarksTag)
          .Cast<std::vector<mediapipe::NormalizedLandmarkList>>();

  // Get the single face landmarks.
  auto& get_vector_item =
      graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator");
  get_vector_item.GetOptions<mediapipe::GetVectorItemCalculatorOptions>()
      .set_item_index(0);
  face_landmarks_lists >> get_vector_item.In(kVectorTag);
  auto single_face_landmarks =
      get_vector_item.Out(kItemTag).Cast<mediapipe::NormalizedLandmarkList>();

  // Convert face landmarks to render data.
  auto face_oval = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksFaceOval
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksFaceOval
                  .size()),
          ColorType::WHITE),
      graph);
  auto lips = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLips
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLips
                  .size()),
          ColorType::WHITE),
      graph);
  auto left_eye = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftEye
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftEye
                  .size()),
          ColorType::GREEN),
      graph);
  auto left_eye_brow = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::
                  kFaceLandmarksLeftEyeBrow.data(),
              face_landmarker::FaceLandmarksConnections::
                  kFaceLandmarksLeftEyeBrow.size()),
          ColorType::GREEN),
      graph);
  auto left_iris = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftIris
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksLeftIris
                  .size()),
          ColorType::GREEN),
      graph);

  auto right_eye = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightEye
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightEye
                  .size()),
          ColorType::BLUE),
      graph);
  auto right_eye_brow = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::
                  kFaceLandmarksRightEyeBrow.data(),
              face_landmarker::FaceLandmarksConnections::
                  kFaceLandmarksRightEyeBrow.size()),
          ColorType::BLUE),
      graph);
  auto right_iris = GetFaceLandmarksRenderData(
      single_face_landmarks,
      GetFaceLandmarksRenderDataOptions(
          absl::Span<const std::array<int, 2>>(
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightIris
                  .data(),
              face_landmarker::FaceLandmarksConnections::kFaceLandmarksRightIris
                  .size()),
          ColorType::BLUE),
      graph);

  // Create a black canvas image with the same size as the face image.
  auto& flat_color = graph.AddNode("FlatColorImageCalculator");
  flat_color.GetOptions<mediapipe::FlatColorImageCalculatorOptions>()
      .mutable_color()
      ->set_r(0);
  face_image >> flat_color.In(kImageTag);
  auto blank_canvas = flat_color.Out(kImageTag);

  // Draw render data on the canvas image.
  auto& annotation_overlay = graph.AddNode("AnnotationOverlayCalculator");
  blank_canvas >> annotation_overlay.In(kUImageTag);
  face_oval >> annotation_overlay.In(0);
  lips >> annotation_overlay.In(1);
  left_eye >> annotation_overlay.In(2);
  left_eye_brow >> annotation_overlay.In(3);
  left_iris >> annotation_overlay.In(4);
  right_eye >> annotation_overlay.In(5);
  right_eye_brow >> annotation_overlay.In(6);
  right_iris >> annotation_overlay.In(7);
  return annotation_overlay.Out(kUImageTag).Cast<Image>();
}

absl::StatusOr<Source<Image>> GetDepthImage(
    Source<Image> image,
    const image_generator::proto::ConditionedImageGraphOptions::
        DepthConditionTypeOptions& depth_condition_type_options,
    Graph& graph) {
  auto& image_segmenter_graph = graph.AddNode(
      "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph");
  image_segmenter_graph
      .GetOptions<image_segmenter::proto::ImageSegmenterGraphOptions>()
      .CopyFrom(depth_condition_type_options.image_segmenter_graph_options());
  image >> image_segmenter_graph.In(kImageTag);
  auto raw_depth_image = image_segmenter_graph.Out(kConfidenceMaskTag);

  auto& depth_postprocessing = graph.AddNode(
      "mediapipe.tasks.vision.image_generator.internal."
      "DepthImagePostprocessingCalculator");
  raw_depth_image >> depth_postprocessing.In(kImageTag);
  return depth_postprocessing.Out(kImageTag).Cast<Image>();
}

absl::StatusOr<Source<Image>> GetEdgeImage(
    Source<Image> image,
    const image_generator::proto::ConditionedImageGraphOptions::
        EdgeConditionTypeOptions& edge_condition_type_options,
    Graph& graph) {
  auto& edge_detector = graph.AddNode(
      "mediapipe.tasks.vision.image_generator.internal."
      "CannyEdgeCalculator");
  edge_detector
      .GetOptions<
          proto::ConditionedImageGraphOptions::EdgeConditionTypeOptions>()
      .CopyFrom(edge_condition_type_options);
  image >> edge_detector.In(kImageTag);
  return edge_detector.Out(kImageTag).Cast<Image>();
}

}  // namespace

// A mediapipe.tasks.vision.image_generator.ConditionedImageGraph converts the
// input image to an image of the requested condition type. The output image
// can be used as input for the diffusion model with a control plugin.
// Inputs:
//   IMAGE - Image
//     Conditioned image to generate the guidance image for the diffusion
//     plugin model.
//
// Outputs:
//   IMAGE - Image
//     The guidance image used as input for the diffusion plugin model.
class ConditionedImageGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    auto& graph_options =
        *sc->MutableOptions<proto::ConditionedImageGraphOptions>();
    Source<Image> conditioned_image = graph.In(kImageTag).Cast<Image>();
    // Configure the guidance graph and get the guidance image if the guidance
    // graph options are set.
    switch (graph_options.condition_type_options_case()) {
      case proto::ConditionedImageGraphOptions::CONDITION_TYPE_OPTIONS_NOT_SET:
        return absl::InvalidArgumentError(
            "Conditioned type options is not set.");
        break;
      case proto::ConditionedImageGraphOptions::kFaceConditionTypeOptions: {
        ASSIGN_OR_RETURN(
            auto face_landmarks_image,
            GetFaceLandmarksImage(conditioned_image,
                                  graph_options.face_condition_type_options(),
                                  graph));
        face_landmarks_image >> graph.Out(kImageTag);
      } break;
      case proto::ConditionedImageGraphOptions::kDepthConditionTypeOptions: {
        ASSIGN_OR_RETURN(
            auto depth_image,
            GetDepthImage(conditioned_image,
                          graph_options.depth_condition_type_options(), graph));
        depth_image >> graph.Out(kImageTag);
      } break;
      case proto::ConditionedImageGraphOptions::kEdgeConditionTypeOptions: {
        ASSIGN_OR_RETURN(
            auto edges_image,
            GetEdgeImage(conditioned_image,
                         graph_options.edge_condition_type_options(), graph));
        edges_image >> graph.Out(kImageTag);
      } break;
    }
    return graph.GetConfig();
  }
};

REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::vision::image_generator::ConditionedImageGraph);

}  // namespace image_generator
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
(ConditionedImageGraph unit tests)
@@ -0,0 +1,147 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_generator {

namespace {

using ::mediapipe::Image;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::vision::DecodeImageFromFile;

constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kFaceLandmarkerModel[] = "face_landmarker_v2.task";
constexpr char kDepthModel[] =
    "mobilenetsweep_dptrigmqn384_unit_384_384_fp16quant_fp32input_opt.tflite";
constexpr char kPortraitImage[] = "portrait.jpg";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageInStream[] = "image_in";
constexpr char kImageOutStream[] = "image_out";

// Helper function to create a TaskRunner for the ConditionedImageGraph.
absl::StatusOr<std::unique_ptr<TaskRunner>>
CreateConditionedImageGraphTaskRunner(
    std::unique_ptr<proto::ConditionedImageGraphOptions> options) {
  Graph graph;
  auto& conditioned_image_graph = graph.AddNode(
      "mediapipe.tasks.vision.image_generator.ConditionedImageGraph");
  conditioned_image_graph.GetOptions<proto::ConditionedImageGraphOptions>()
      .Swap(options.get());
  graph.In(kImageTag).Cast<Image>().SetName(kImageInStream) >>
      conditioned_image_graph.In(kImageTag);
  conditioned_image_graph.Out(kImageTag).SetName(kImageOutStream) >>
      graph.Out(kImageTag).Cast<Image>();
  return core::TaskRunner::Create(
      graph.GetConfig(),
      absl::make_unique<tasks::core::MediaPipeBuiltinOpResolver>());
}

TEST(ConditionedImageGraphTest, SucceedsFaceLandmarkerConditionType) {
  auto options = std::make_unique<proto::ConditionedImageGraphOptions>();
  options->mutable_face_condition_type_options()
      ->mutable_face_landmarker_graph_options()
      ->mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(
          file::JoinPath("./", kTestDataDirectory, kFaceLandmarkerModel));
  options->mutable_face_condition_type_options()
      ->mutable_face_landmarker_graph_options()
      ->mutable_face_detector_graph_options()
      ->set_num_faces(1);
  MP_ASSERT_OK_AND_ASSIGN(
      auto runner, CreateConditionedImageGraphTaskRunner(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory,
                                                      kPortraitImage)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_packets,
      runner->Process({{kImageInStream, MakePacket<Image>(std::move(image))}}));
  const auto& output_image = output_packets[kImageOutStream].Get<Image>();
  MP_EXPECT_OK(SavePngTestOutput(*output_image.GetImageFrameSharedPtr(),
                                 "face_landmarks_image"));
}

TEST(ConditionedImageGraphTest, SucceedsDepthConditionType) {
  auto options = std::make_unique<proto::ConditionedImageGraphOptions>();
  options->mutable_depth_condition_type_options()
      ->mutable_image_segmenter_graph_options()
      ->mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(file::JoinPath("./", kTestDataDirectory, kDepthModel));
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory,
                                                      kPortraitImage)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto runner, CreateConditionedImageGraphTaskRunner(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_packets,
      runner->Process({{kImageInStream, MakePacket<Image>(std::move(image))}}));
  const auto& output_image = output_packets[kImageOutStream].Get<Image>();
  MP_EXPECT_OK(
      SavePngTestOutput(*output_image.GetImageFrameSharedPtr(), "depth_image"));
}

TEST(ConditionedImageGraphTest, SucceedsEdgeConditionType) {
  auto options = std::make_unique<proto::ConditionedImageGraphOptions>();
  auto edge_condition_type_options =
      options->mutable_edge_condition_type_options();
  edge_condition_type_options->set_threshold_1(100);
  edge_condition_type_options->set_threshold_2(200);
  edge_condition_type_options->set_aperture_size(3);
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(file::JoinPath("./", kTestDataDirectory,
                                                      kPortraitImage)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto runner, CreateConditionedImageGraphTaskRunner(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_packets,
      runner->Process({{kImageInStream, MakePacket<Image>(std::move(image))}}));
  const auto& output_image = output_packets[kImageOutStream].Get<Image>();
  MP_EXPECT_OK(
      SavePngTestOutput(*output_image.GetImageFrameSharedPtr(), "edges_image"));
}

}  // namespace
}  // namespace image_generator
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
mediapipe/tasks/cc/vision/image_generator/diffuser/BUILD (new file, 70 lines)
@@ -0,0 +1,70 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

licenses(["notice"])

package(default_visibility = ["//mediapipe/tasks:internal"])

cc_library(
    name = "diffuser_gpu_header",
    hdrs = ["diffuser_gpu.h"],
    visibility = [
        "//mediapipe/tasks/cc/vision/image_generator/diffuser:__pkg__",
    ],
)

mediapipe_proto_library(
    name = "stable_diffusion_iterate_calculator_proto",
    srcs = ["stable_diffusion_iterate_calculator.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
)

cc_library(
    name = "stable_diffusion_iterate_calculator",
    srcs = ["stable_diffusion_iterate_calculator.cc"],
    deps = [
        ":diffuser_gpu_header",
        ":stable_diffusion_iterate_calculator_cc_proto",
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/deps:file_helpers",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/formats:tensor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
    ],
    alwayslink = 1,
)

cc_library(
    name = "diffusion_plugins_output_calculator",
    srcs = ["diffusion_plugins_output_calculator.cc"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/formats:tensor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)
mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h (new file, 88 lines)
@@ -0,0 +1,88 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_

#include <limits.h>
#include <stdint.h>

#ifndef DG_EXPORT
#define DG_EXPORT __attribute__((visibility("default")))
#endif  // DG_EXPORT

#ifdef __cplusplus
extern "C" {
#endif

enum DiffuserModelType {
  kDiffuserModelTypeSd1,
  kDiffuserModelTypeGldm,
  kDiffuserModelTypeDistilledGldm,
  kDiffuserModelTypeSd2Base,
  kDiffuserModelTypeTigo,
};

enum DiffuserPriorityHint {
  kDiffuserPriorityHintHigh,
  kDiffuserPriorityHintNormal,
  kDiffuserPriorityHintLow,
};

enum DiffuserPerformanceHint {
  kDiffuserPerformanceHintHigh,
  kDiffuserPerformanceHintNormal,
  kDiffuserPerformanceHintLow,
};

typedef struct {
  DiffuserPriorityHint priority_hint;
  DiffuserPerformanceHint performance_hint;
} DiffuserEnvironmentOptions;

typedef struct {
  DiffuserModelType model_type;
  char model_dir[PATH_MAX];
  char lora_dir[PATH_MAX];
  const void* lora_weights_layer_mapping;
  int lora_rank;
  int seed;
  int image_width;
  int image_height;
  int run_unet_with_plugins;
  float plugins_strength;
  DiffuserEnvironmentOptions env_options;
} DiffuserConfig;

typedef struct {
  void* diffuser;
} DiffuserContext;

typedef struct {
  int shape[4];
  const float* data;
} DiffuserPluginTensor;

DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*);  // NOLINT
DG_EXPORT int DiffuserReset(DiffuserContext*,  // NOLINT
                            const char*, int, int, const void*);
DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int);  // NOLINT
DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*);   // NOLINT
DG_EXPORT void DiffuserDelete(DiffuserContext*);            // NOLINT

#ifdef __cplusplus
}
#endif  // __cplusplus

#endif  // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_DIFFUSER_DIFFUSER_GPU_H_
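For orientation, the calculator later in this commit resolves these symbols with dlopen/dlsym and drives them in a fixed order. Below is a minimal sketch of that call sequence, assuming the symbols are linked directly and using hypothetical values (model folder, image size, step count); it is not part of the header.

// Sketch only: illustrates the expected DiffuserCreate -> DiffuserReset ->
// DiffuserIterate -> DiffuserDecode -> DiffuserDelete sequence.
#include <cstdint>
#include <cstring>

#include "mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h"

bool RunOnce(const char* prompt, int steps, int seed, uint8_t* rgb_out) {
  DiffuserConfig config = {};
  config.model_type = kDiffuserModelTypeSd1;
  std::strcpy(config.model_dir, "bins/");  // Hypothetical weight folder.
  config.seed = seed;
  config.image_width = 512;   // rgb_out must hold width * height * 3 bytes.
  config.image_height = 512;
  config.env_options = {kDiffuserPriorityHintNormal, kDiffuserPerformanceHintHigh};

  DiffuserContext* ctx = DiffuserCreate(&config);
  if (!ctx) return false;
  // No plugin tensors in this sketch, so the last argument is nullptr.
  if (!DiffuserReset(ctx, prompt, steps, seed, nullptr)) return false;
  for (int i = 0; i < steps; ++i) {
    if (!DiffuserIterate(ctx, steps, i)) return false;  // One denoising step.
  }
  const bool ok = DiffuserDecode(ctx, rgb_out);  // Decode latent to RGB bytes.
  DiffuserDelete(ctx);
  return ok;
}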
mediapipe/tasks/cc/vision/image_generator/diffuser/diffusion_plugins_output_calculator.cc (new file, 67 lines)
@@ -0,0 +1,67 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"

namespace mediapipe {
namespace api2 {

// In iteration mode, output the image guidance tensors at the current timestamp
// and advance the output stream timestamp bound by the number of steps.
// Otherwise, output the image guidance tensors at the current timestamp only.
class DiffusionPluginsOutputCalculator : public Node {
 public:
  static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
  static constexpr Input<int> kStepsIn{"STEPS"};
  static constexpr Input<int>::Optional kIterationIn{"ITERATION"};
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
  MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kStepsIn, kIterationIn, kTensorsOut);

  absl::Status Process(CalculatorContext* cc) override {
    if (kTensorsIn(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    // Consumes the tensor vector to avoid data copy.
    absl::StatusOr<std::unique_ptr<std::vector<Tensor>>> status_or_tensor =
        cc->Inputs().Tag("TENSORS").Value().Consume<std::vector<Tensor>>();
    if (!status_or_tensor.ok()) {
      return absl::InternalError("Input tensor vector is not consumable.");
    }
    if (kIterationIn(cc).IsConnected()) {
      CHECK_EQ(kIterationIn(cc).Get(), 0);
      kTensorsOut(cc).Send(std::move(*status_or_tensor.value()));
      kTensorsOut(cc).SetNextTimestampBound(cc->InputTimestamp() +
                                            kStepsIn(cc).Get());
    } else {
      kTensorsOut(cc).Send(std::move(*status_or_tensor.value()));
    }
    return absl::OkStatus();
  }
};

MEDIAPIPE_REGISTER_NODE(DiffusionPluginsOutputCalculator);

}  // namespace api2
}  // namespace mediapipe
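To make the contract above concrete, here is a hedged builder sketch (stream names are illustrative, not from this commit) that wires the plugin tensors, the step count, and the iteration counter into the calculator. With ITERATION connected, the tensors sent at iteration 0 are the only output packet of the run, because the output timestamp bound is advanced by the number of steps.

// Sketch only: wiring DiffusionPluginsOutputCalculator with api2::builder.
#include <vector>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"

mediapipe::CalculatorGraphConfig BuildPluginsOutputGraph() {
  using ::mediapipe::api2::builder::Graph;
  Graph graph;
  auto tensors = graph.In("TENSORS").Cast<std::vector<mediapipe::Tensor>>();
  auto steps = graph.In("STEPS").Cast<int>();
  auto iteration = graph.In("ITERATION").Cast<int>();

  auto& node = graph.AddNode("DiffusionPluginsOutputCalculator");
  tensors >> node.In("TENSORS");
  steps >> node.In("STEPS");
  iteration >> node.In("ITERATION");
  // With ITERATION connected, the single packet sent at iteration 0 covers
  // the whole run; the timestamp bound jumps ahead by STEPS.
  node.Out("TENSORS") >> graph.Out("TENSORS");
  return graph.GetConfig();
}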
mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc (new file, 278 lines)
@@ -0,0 +1,278 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <dlfcn.h>

#include <cmath>
#include <cstdint>
#include <cstring>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_helpers.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h"
#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h"

namespace mediapipe {
namespace api2 {
namespace {

DiffuserPriorityHint ToDiffuserPriorityHint(
    StableDiffusionIterateCalculatorOptions::ClPriorityHint priority) {
  switch (priority) {
    case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_LOW:
      return kDiffuserPriorityHintLow;
    case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_NORMAL:
      return kDiffuserPriorityHintNormal;
    case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_HIGH:
      return kDiffuserPriorityHintHigh;
  }
  return kDiffuserPriorityHintNormal;
}

DiffuserModelType ToDiffuserModelType(
    StableDiffusionIterateCalculatorOptions::ModelType model_type) {
  switch (model_type) {
    case StableDiffusionIterateCalculatorOptions::DEFAULT:
    case StableDiffusionIterateCalculatorOptions::SD_1:
      return kDiffuserModelTypeSd1;
  }
  return kDiffuserModelTypeSd1;
}

}  // namespace

// Runs diffusion models including, but not limited to, Stable Diffusion & gLDM.
//
// Inputs:
//   PROMPT - std::string
//     The prompt used to generate the image.
//   STEPS - int
//     The number of steps to run the UNet.
//   ITERATION - int
//     The iteration of the current run.
//   PLUGIN_TENSORS - std::vector<mediapipe::Tensor> @Optional
//     The output tensor vector of the diffusion plugins model.
//
// Outputs:
//   IMAGE - mediapipe::ImageFrame
//     The image generated by the Stable Diffusion model from the input prompt.
//     The output image is in RGB format.
//
// Example:
// node {
//   calculator: "StableDiffusionIterateCalculator"
//   input_stream: "PROMPT:prompt"
//   input_stream: "STEPS:steps"
//   output_stream: "IMAGE:result"
//   options {
//     [mediapipe.StableDiffusionIterateCalculatorOptions.ext] {
//       base_seed: 0
//       model_type: SD_1
//     }
//   }
// }
class StableDiffusionIterateCalculator : public Node {
 public:
  static constexpr Input<std::string> kPromptIn{"PROMPT"};
  static constexpr Input<int> kStepsIn{"STEPS"};
  static constexpr Input<int>::Optional kIterationIn{"ITERATION"};
  static constexpr Input<int>::Optional kRandSeedIn{"RAND_SEED"};
  static constexpr SideInput<StableDiffusionIterateCalculatorOptions>::Optional
      kOptionsIn{"OPTIONS"};
  static constexpr Input<std::vector<Tensor>>::Optional kPlugInTensorsIn{
      "PLUGIN_TENSORS"};
  static constexpr Output<mediapipe::ImageFrame> kImageOut{"IMAGE"};
  MEDIAPIPE_NODE_CONTRACT(kPromptIn, kStepsIn, kIterationIn, kRandSeedIn,
                          kPlugInTensorsIn, kOptionsIn, kImageOut);

  ~StableDiffusionIterateCalculator() {
    if (context_) DiffuserDelete();
    if (handle_) dlclose(handle_);
  }

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  std::vector<DiffuserPluginTensor> GetPluginTensors(
      CalculatorContext* cc) const {
    if (!kPlugInTensorsIn(cc).IsConnected()) return {};
    std::vector<DiffuserPluginTensor> diffuser_tensors;
    diffuser_tensors.reserve(kPlugInTensorsIn(cc)->size());
    for (const auto& mp_tensor : *kPlugInTensorsIn(cc)) {
      DiffuserPluginTensor diffuser_tensor;
      diffuser_tensor.shape[0] = mp_tensor.shape().dims[0];
      diffuser_tensor.shape[1] = mp_tensor.shape().dims[1];
      diffuser_tensor.shape[2] = mp_tensor.shape().dims[2];
      diffuser_tensor.shape[3] = mp_tensor.shape().dims[3];
      diffuser_tensor.data = mp_tensor.GetCpuReadView().buffer<float>();
      diffuser_tensors.push_back(diffuser_tensor);
    }
    return diffuser_tensors;
  }

  absl::Status LoadDiffuser() {
    handle_ = dlopen("libimagegenerator_gpu.so", RTLD_NOW | RTLD_LOCAL);
    RET_CHECK(handle_) << dlerror();
    create_ptr_ = reinterpret_cast<DiffuserContext* (*)(const DiffuserConfig*)>(
        dlsym(handle_, "DiffuserCreate"));
    RET_CHECK(create_ptr_) << dlerror();
    reset_ptr_ =
        reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int,
                                 const void*)>(dlsym(handle_, "DiffuserReset"));
    RET_CHECK(reset_ptr_) << dlerror();
    iterate_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, int, int)>(
        dlsym(handle_, "DiffuserIterate"));
    RET_CHECK(iterate_ptr_) << dlerror();
    decode_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, uint8_t*)>(
        dlsym(handle_, "DiffuserDecode"));
    RET_CHECK(decode_ptr_) << dlerror();
    delete_ptr_ = reinterpret_cast<void (*)(DiffuserContext*)>(
        dlsym(handle_, "DiffuserDelete"));
    RET_CHECK(delete_ptr_) << dlerror();
    return absl::OkStatus();
  }

  DiffuserContext* DiffuserCreate(const DiffuserConfig* a) {
    return (*create_ptr_)(a);
  }
  bool DiffuserReset(const char* a, int b, int c,
                     const std::vector<DiffuserPluginTensor>* d) {
    return (*reset_ptr_)(context_, a, b, c, d);
  }
  bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); }
  bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); }
  void DiffuserDelete() { (*delete_ptr_)(context_); }

  void* handle_ = nullptr;
  DiffuserContext* context_ = nullptr;
  DiffuserContext* (*create_ptr_)(const DiffuserConfig*);
  int (*reset_ptr_)(DiffuserContext*, const char*, int, int, const void*);
  int (*iterate_ptr_)(DiffuserContext*, int, int);
  int (*decode_ptr_)(DiffuserContext*, uint8_t*);
  void (*delete_ptr_)(DiffuserContext*);

  int show_every_n_iteration_;
  bool emit_empty_packet_;
};

absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) {
  StableDiffusionIterateCalculatorOptions options;
  if (kOptionsIn(cc).IsEmpty()) {
    options = cc->Options<StableDiffusionIterateCalculatorOptions>();
  } else {
    options = kOptionsIn(cc).Get();
  }
  show_every_n_iteration_ = options.show_every_n_iteration();
  emit_empty_packet_ = options.emit_empty_packet();

  MP_RETURN_IF_ERROR(LoadDiffuser());

  DiffuserConfig config;
  config.model_type = ToDiffuserModelType(options.model_type());
  if (options.file_folder().empty()) {
    std::strcpy(config.model_dir, "bins/");  // NOLINT
  } else {
    std::strcpy(config.model_dir, options.file_folder().c_str());  // NOLINT
  }
  MP_RETURN_IF_ERROR(mediapipe::file::Exists(config.model_dir))
      << config.model_dir;
  RET_CHECK(options.lora_file_folder().empty() ||
            options.lora_weights_layer_mapping().empty())
      << "Can't set both lora_file_folder and lora_weights_layer_mapping.";
  std::strcpy(config.lora_dir, options.lora_file_folder().c_str());  // NOLINT
  std::map<std::string, const char*> lora_weights_layer_mapping;
  for (auto& layer_name_and_weights : options.lora_weights_layer_mapping()) {
    lora_weights_layer_mapping[layer_name_and_weights.first] =
        (char*)layer_name_and_weights.second;
  }
  config.lora_weights_layer_mapping = !lora_weights_layer_mapping.empty()
                                          ? &lora_weights_layer_mapping
                                          : nullptr;
  config.lora_rank = options.lora_rank();
  config.seed = options.base_seed();
  config.image_width = options.output_image_width();
  config.image_height = options.output_image_height();
  config.run_unet_with_plugins = kPlugInTensorsIn(cc).IsConnected();
  config.env_options = {
      .priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()),
      .performance_hint = kDiffuserPerformanceHintHigh,
  };
  config.plugins_strength = options.plugins_strength();
  RET_CHECK(config.plugins_strength >= 0.0f && config.plugins_strength <= 1.0f)
      << "The value of plugins_strength must be in the range of [0, 1].";
  context_ = DiffuserCreate(&config);
  RET_CHECK(context_);
  return absl::OkStatus();
}

absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
  const auto& options =
      cc->Options().GetExtension(StableDiffusionIterateCalculatorOptions::ext);
  const std::string& prompt = *kPromptIn(cc);
  const int steps = *kStepsIn(cc);
  const int rand_seed = !kRandSeedIn(cc).IsEmpty() ? std::abs(*kRandSeedIn(cc))
                                                   : options.base_seed();

  if (kIterationIn(cc).IsEmpty()) {
    const auto plugin_tensors = GetPluginTensors(cc);
    RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors));
    for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
    ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
                         options.output_image_height());
    RET_CHECK(DiffuserDecode(image_out.MutablePixelData()));
    kImageOut(cc).Send(std::move(image_out));
  } else {
    const int iteration = *kIterationIn(cc);
    RET_CHECK_LT(iteration, steps);

    // Extract text embedding on first iteration.
    if (iteration == 0) {
      const auto plugin_tensors = GetPluginTensors(cc);
      RET_CHECK(
          DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors));
    }

    RET_CHECK(DiffuserIterate(steps, iteration));

    // Decode the output and send out the image for visualization.
    if ((iteration + 1) % show_every_n_iteration_ == 0 ||
        iteration == steps - 1) {
      ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
                           options.output_image_height());
      RET_CHECK(DiffuserDecode(image_out.MutablePixelData()));
      kImageOut(cc).Send(std::move(image_out));
    } else if (emit_empty_packet_) {
      kImageOut(cc).Send(Packet<mediapipe::ImageFrame>());
    }
  }
  return absl::OkStatus();
}

MEDIAPIPE_REGISTER_NODE(StableDiffusionIterateCalculator);

}  // namespace api2
}  // namespace mediapipe
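The decode cadence in Process() above is easy to misread, so here is a small standalone sketch of the same condition. With steps = 20 and show_every_n_iteration = 5 it decodes after iterations 4, 9, 14, and 19, i.e. every fifth step and always on the final step; the values are illustrative.

// Sketch only: mirrors the decode condition used in Process() above.
#include <iostream>

bool ShouldDecode(int iteration, int steps, int show_every_n_iteration) {
  return (iteration + 1) % show_every_n_iteration == 0 ||
         iteration == steps - 1;
}

int main() {
  const int steps = 20;
  const int show_every_n = 5;
  for (int i = 0; i < steps; ++i) {
    if (ShouldDecode(i, steps, show_every_n)) std::cout << i << " ";  // 4 9 14 19
  }
  std::cout << "\n";
  return 0;
}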
@ -0,0 +1,84 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.calculator.proto";
|
||||
option java_outer_classname = "StableDiffusionIterateCalculatorOptionsProto";
|
||||
|
||||
message StableDiffusionIterateCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional StableDiffusionIterateCalculatorOptions ext = 510855836;
|
||||
}
|
||||
|
||||
// The random seed that is fed into the calculator to control the randomness
|
||||
// of the generated image.
|
||||
optional uint32 base_seed = 1 [default = 0];
|
||||
|
||||
// The target output image size. Must be a multiple of 8 and larger than 384.
|
||||
optional int32 output_image_width = 2 [default = 512];
|
||||
optional int32 output_image_height = 3 [default = 512];
|
||||
|
||||
// The folder name must end of '/'.
|
||||
optional string file_folder = 4 [default = "bins/"];
|
||||
|
||||
// Note: only one of lora_file_folder and lora_weights_layer_mapping should be
|
||||
// set.
|
||||
// The LoRA file folder. The folder name must end of '/'.
|
||||
optional string lora_file_folder = 9 [default = ""];
|
||||
|
||||
// The LoRA layer name mapping to the weight buffer position in the file.
|
||||
map<string, uint64> lora_weights_layer_mapping = 10;
|
||||
|
||||
// The LoRA rank.
|
||||
optional int32 lora_rank = 12 [default = 4];
|
||||
|
||||
// Determines how often to run image decoding, i.e. once every N iterations.
// Setting this to 1 runs the image decoding on every iteration to display the
// intermediate result, but it also introduces much higher overall latency.
// Setting this to the targeted number of iterations runs the image decoding
// only at the end, giving the best overall latency.
|
||||
optional int32 show_every_n_iteration = 5 [default = 1];
|
||||
|
||||
// If set to true, the calculator will perform a GPU-CPU sync and emit an
// empty packet. This signals which iteration the generation is currently at,
// typically to drive a progress bar. Note that this also introduces overhead,
// though not significantly based on our experiments (~1ms).
|
||||
optional bool emit_empty_packet = 6 [default = false];
|
||||
|
||||
enum ClPriorityHint {
|
||||
PRIORITY_HINT_NORMAL = 0; // Default, must be first.
|
||||
PRIORITY_HINT_LOW = 1;
|
||||
PRIORITY_HINT_HIGH = 2;
|
||||
}
|
||||
|
||||
// OpenCL priority hint. Set this to LOW to yield to other GPU contexts.
|
||||
// This lowers inference speed, but helps keep the UI responsive.
|
||||
optional ClPriorityHint cl_priority_hint = 7;
|
||||
|
||||
enum ModelType {
|
||||
DEFAULT = 0;
|
||||
SD_1 = 1; // Stable Diffusion v1 models, including SD 1.4 and 1.5.
|
||||
}
|
||||
// Stable Diffusion model type. Defaults to Stable Diffusion v1.
|
||||
optional ModelType model_type = 8 [default = SD_1];
|
||||
// The strength of the diffusion plugin inputs.
|
||||
optional float plugins_strength = 11 [default = 1.0];
|
||||
}
|
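To make the show_every_n_iteration trade-off described above concrete, here is a minimal standalone C++ sketch (not part of this change; the step counts are hypothetical) that mirrors the decode cadence used by StableDiffusionIterateCalculator: the image is decoded on every N-th iteration and always on the final one.

#include <iostream>

// Mirrors the calculator's condition:
//   (iteration + 1) % show_every_n_iteration == 0 || iteration == steps - 1
bool DecodesImageAt(int iteration, int steps, int show_every_n_iteration) {
  return (iteration + 1) % show_every_n_iteration == 0 ||
         iteration == steps - 1;
}

int main() {
  const int steps = 25;                    // Hypothetical values for illustration.
  const int show_every_n_iteration = 10;
  for (int i = 0; i < steps; ++i) {
    if (DecodesImageAt(i, steps, show_every_n_iteration)) {
      std::cout << "decode at iteration " << i << "\n";  // Prints 9, 19, 24.
    }
  }
  return 0;
}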
397
mediapipe/tasks/cc/vision/image_generator/image_generator.cc
Normal file
|
@ -0,0 +1,397 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/image_generator/image_generator.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_generator/image_generator_result.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_generator {
|
||||
namespace {
|
||||
|
||||
using ImageGeneratorGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||
image_generator::proto::ImageGeneratorGraphOptions;
|
||||
using ConditionedImageGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||
image_generator::proto::ConditionedImageGraphOptions;
|
||||
using ControlPluginGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||
image_generator::proto::ControlPluginGraphOptions;
|
||||
using FaceLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||
face_landmarker::proto::FaceLandmarkerGraphOptions;
|
||||
|
||||
constexpr absl::string_view kImageTag = "IMAGE";
|
||||
constexpr absl::string_view kImageOutName = "image_out";
|
||||
constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
|
||||
constexpr absl::string_view kConditionImageName = "condition_image";
|
||||
constexpr absl::string_view kSourceConditionImageName =
|
||||
"source_condition_image";
|
||||
constexpr absl::string_view kStepsTag = "STEPS";
|
||||
constexpr absl::string_view kStepsName = "steps";
|
||||
constexpr absl::string_view kIterationTag = "ITERATION";
|
||||
constexpr absl::string_view kIterationName = "iteration";
|
||||
constexpr absl::string_view kPromptTag = "PROMPT";
|
||||
constexpr absl::string_view kPromptName = "prompt";
|
||||
constexpr absl::string_view kRandSeedTag = "RAND_SEED";
|
||||
constexpr absl::string_view kRandSeedName = "rand_seed";
|
||||
constexpr absl::string_view kSelectTag = "SELECT";
|
||||
constexpr absl::string_view kSelectName = "select";
|
||||
|
||||
constexpr char kImageGeneratorGraphTypeName[] =
|
||||
"mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
|
||||
|
||||
constexpr char kConditionedImageGraphContainerTypeName[] =
|
||||
"mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
|
||||
|
||||
// Creates a MediaPipe graph config that contains a subgraph node of
|
||||
// "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph".
|
||||
CalculatorGraphConfig CreateImageGeneratorGraphConfig(
|
||||
std::unique_ptr<ImageGeneratorGraphOptionsProto> options,
|
||||
bool use_condition_image) {
|
||||
api2::builder::Graph graph;
|
||||
auto& subgraph = graph.AddNode(kImageGeneratorGraphTypeName);
|
||||
subgraph.GetOptions<ImageGeneratorGraphOptionsProto>().CopyFrom(*options);
|
||||
graph.In(kStepsTag).SetName(kStepsName) >> subgraph.In(kStepsTag);
|
||||
graph.In(kIterationTag).SetName(kIterationName) >> subgraph.In(kIterationTag);
|
||||
graph.In(kPromptTag).SetName(kPromptName) >> subgraph.In(kPromptTag);
|
||||
graph.In(kRandSeedTag).SetName(kRandSeedName) >> subgraph.In(kRandSeedTag);
|
||||
if (use_condition_image) {
|
||||
graph.In(kConditionImageTag).SetName(kConditionImageName) >>
|
||||
subgraph.In(kConditionImageTag);
|
||||
graph.In(kSelectTag).SetName(kSelectName) >> subgraph.In(kSelectTag);
|
||||
}
|
||||
subgraph.Out(kImageTag).SetName(kImageOutName) >>
|
||||
graph[api2::Output<Image>::Optional(kImageTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
// Creates a MediaPipe graph config that contains a subgraph node of
|
||||
// "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer".
|
||||
CalculatorGraphConfig CreateConditionedImageGraphContainerConfig(
|
||||
std::unique_ptr<ImageGeneratorGraphOptionsProto> options) {
|
||||
api2::builder::Graph graph;
|
||||
auto& subgraph = graph.AddNode(kConditionedImageGraphContainerTypeName);
|
||||
subgraph.GetOptions<ImageGeneratorGraphOptionsProto>().CopyFrom(*options);
|
||||
graph.In(kImageTag).SetName(kSourceConditionImageName) >>
|
||||
subgraph.In(kImageTag);
|
||||
graph.In(kSelectTag).SetName(kSelectName) >> subgraph.In(kSelectTag);
|
||||
subgraph.Out(kConditionImageTag).SetName(kConditionImageName) >>
|
||||
graph.Out(kConditionImageTag).Cast<Image>();
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
absl::Status SetFaceConditionOptionsToProto(
|
||||
FaceConditionOptions& face_condition_options,
|
||||
ControlPluginGraphOptionsProto& options_proto) {
|
||||
// Configure face plugin model.
|
||||
auto plugin_base_options_proto =
|
||||
std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(
|
||||
&(face_condition_options.base_options)));
|
||||
options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get());
|
||||
|
||||
// Configure face landmarker graph.
|
||||
auto& face_landmarker_options =
|
||||
face_condition_options.face_landmarker_options;
|
||||
auto& face_landmarker_options_proto =
|
||||
*options_proto.mutable_conditioned_image_graph_options()
|
||||
->mutable_face_condition_type_options()
|
||||
->mutable_face_landmarker_graph_options();
|
||||
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(
|
||||
&(face_landmarker_options.base_options)));
|
||||
face_landmarker_options_proto.mutable_base_options()->Swap(
|
||||
base_options_proto.get());
|
||||
face_landmarker_options_proto.mutable_base_options()->set_use_stream_mode(
|
||||
false);
|
||||
|
||||
// Configure face detector options.
|
||||
auto* face_detector_graph_options =
|
||||
face_landmarker_options_proto.mutable_face_detector_graph_options();
|
||||
face_detector_graph_options->set_num_faces(face_landmarker_options.num_faces);
|
||||
face_detector_graph_options->set_min_detection_confidence(
|
||||
face_landmarker_options.min_face_detection_confidence);
|
||||
|
||||
// Configure face landmark detector options.
|
||||
face_landmarker_options_proto.set_min_tracking_confidence(
|
||||
face_landmarker_options.min_tracking_confidence);
|
||||
auto* face_landmarks_detector_graph_options =
|
||||
face_landmarker_options_proto
|
||||
.mutable_face_landmarks_detector_graph_options();
|
||||
face_landmarks_detector_graph_options->set_min_detection_confidence(
|
||||
face_landmarker_options.min_face_presence_confidence);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status SetDepthConditionOptionsToProto(
|
||||
DepthConditionOptions& depth_condition_options,
|
||||
ControlPluginGraphOptionsProto& options_proto) {
|
||||
// Configure the depth plugin model.
|
||||
auto plugin_base_options_proto =
|
||||
std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(
|
||||
&(depth_condition_options.base_options)));
|
||||
options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get());
|
||||
|
||||
auto& image_segmenter_graph_options =
|
||||
*options_proto.mutable_conditioned_image_graph_options()
|
||||
->mutable_depth_condition_type_options()
|
||||
->mutable_image_segmenter_graph_options();
|
||||
|
||||
auto depth_base_options_proto =
|
||||
std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(
|
||||
&(depth_condition_options.image_segmenter_options.base_options)));
|
||||
image_segmenter_graph_options.mutable_base_options()->Swap(
|
||||
depth_base_options_proto.get());
|
||||
image_segmenter_graph_options.mutable_base_options()->set_use_stream_mode(
|
||||
false);
|
||||
image_segmenter_graph_options.set_display_names_locale(
|
||||
depth_condition_options.image_segmenter_options.display_names_locale);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status SetEdgeConditionOptionsToProto(
|
||||
EdgeConditionOptions& edge_condition_options,
|
||||
ControlPluginGraphOptionsProto& options_proto) {
|
||||
auto plugin_base_options_proto =
|
||||
std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(
|
||||
&(edge_condition_options.base_options)));
|
||||
options_proto.mutable_base_options()->Swap(plugin_base_options_proto.get());
|
||||
|
||||
auto& edge_options_proto =
|
||||
*options_proto.mutable_conditioned_image_graph_options()
|
||||
->mutable_edge_condition_type_options();
|
||||
edge_options_proto.set_threshold_1(edge_condition_options.threshold_1);
|
||||
edge_options_proto.set_threshold_2(edge_condition_options.threshold_2);
|
||||
edge_options_proto.set_aperture_size(edge_condition_options.aperture_size);
|
||||
edge_options_proto.set_l2_gradient(edge_condition_options.l2_gradient);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Helper holder struct of image generator graph options and condition type
|
||||
// index mapping.
|
||||
struct ImageGeneratorOptionsProtoAndConditionTypeIndex {
|
||||
std::unique_ptr<ImageGeneratorGraphOptionsProto> options_proto;
|
||||
std::unique_ptr<std::map<ConditionOptions::ConditionType, int>>
|
||||
condition_type_index;
|
||||
};
|
||||
|
||||
// Converts the user-facing ImageGeneratorOptions struct to the internal
|
||||
// ImageGeneratorOptions proto.
|
||||
absl::StatusOr<ImageGeneratorOptionsProtoAndConditionTypeIndex>
|
||||
ConvertImageGeneratorGraphOptionsProto(
|
||||
ImageGeneratorOptions* image_generator_options,
|
||||
ConditionOptions* condition_options) {
|
||||
ImageGeneratorOptionsProtoAndConditionTypeIndex
|
||||
options_proto_and_condition_index;
|
||||
|
||||
// Configure base image generator options.
|
||||
options_proto_and_condition_index.options_proto =
|
||||
std::make_unique<ImageGeneratorGraphOptionsProto>();
|
||||
auto& options_proto = *options_proto_and_condition_index.options_proto;
|
||||
options_proto.set_text2image_model_directory(
|
||||
image_generator_options->text2image_model_directory);
|
||||
if (image_generator_options->lora_weights_file_path.has_value()) {
|
||||
options_proto.mutable_lora_weights_file()->set_file_name(
|
||||
*image_generator_options->lora_weights_file_path);
|
||||
}
|
||||
|
||||
// Configure optional condition type options.
|
||||
if (condition_options != nullptr) {
|
||||
options_proto_and_condition_index.condition_type_index =
|
||||
std::make_unique<std::map<ConditionOptions::ConditionType, int>>();
|
||||
auto& condition_type_index =
|
||||
*options_proto_and_condition_index.condition_type_index;
|
||||
if (condition_options->face_condition_options.has_value()) {
|
||||
condition_type_index[ConditionOptions::FACE] =
|
||||
condition_type_index.size();
|
||||
auto& face_plugin_graph_options =
|
||||
*options_proto.add_control_plugin_graphs_options();
|
||||
RET_CHECK_OK(SetFaceConditionOptionsToProto(
|
||||
*condition_options->face_condition_options,
|
||||
face_plugin_graph_options));
|
||||
}
|
||||
if (condition_options->depth_condition_options.has_value()) {
|
||||
condition_type_index[ConditionOptions::DEPTH] =
|
||||
condition_type_index.size();
|
||||
auto& depth_plugin_graph_options =
|
||||
*options_proto.add_control_plugin_graphs_options();
|
||||
RET_CHECK_OK(SetDepthConditionOptionsToProto(
|
||||
*condition_options->depth_condition_options,
|
||||
depth_plugin_graph_options));
|
||||
}
|
||||
if (condition_options->edge_condition_options.has_value()) {
|
||||
condition_type_index[ConditionOptions::EDGE] =
|
||||
condition_type_index.size();
|
||||
auto& edge_plugin_graph_options =
|
||||
*options_proto.add_control_plugin_graphs_options();
|
||||
RET_CHECK_OK(SetEdgeConditionOptionsToProto(
|
||||
*condition_options->edge_condition_options,
|
||||
edge_plugin_graph_options));
|
||||
}
|
||||
if (condition_type_index.empty()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"At least one condition type must be set.");
|
||||
}
|
||||
}
|
||||
return options_proto_and_condition_index;
|
||||
}
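As a side note on the mapping built by ConvertImageGeneratorGraphOptionsProto above: condition types receive SELECT indices in the order FACE, DEPTH, EDGE, counting only the options that are actually set. A small illustrative sketch (hypothetical stand-in enum, mirroring that logic) for a generator created with only depth and edge conditions:

#include <iostream>
#include <map>

// Hypothetical stand-in for ConditionOptions::ConditionType.
enum ConditionType { FACE, EDGE, DEPTH };

int main() {
  // Suppose only depth and edge condition options are set; FACE is skipped.
  std::map<ConditionType, int> condition_type_index;
  int next_index = 0;
  condition_type_index[DEPTH] = next_index++;  // SELECT = 0
  condition_type_index[EDGE] = next_index++;   // SELECT = 1
  std::cout << "DEPTH -> " << condition_type_index[DEPTH]
            << ", EDGE -> " << condition_type_index[EDGE] << "\n";
  return 0;
}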
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<std::unique_ptr<ImageGenerator>> ImageGenerator::Create(
|
||||
std::unique_ptr<ImageGeneratorOptions> image_generator_options,
|
||||
std::unique_ptr<ConditionOptions> condition_options) {
|
||||
bool use_condition_image = condition_options != nullptr;
|
||||
ASSIGN_OR_RETURN(auto options_proto_and_condition_index,
|
||||
ConvertImageGeneratorGraphOptionsProto(
|
||||
image_generator_options.get(), condition_options.get()));
|
||||
std::unique_ptr<proto::ImageGeneratorGraphOptions>
|
||||
options_proto_for_condition_image_graphs_container;
|
||||
if (use_condition_image) {
|
||||
options_proto_for_condition_image_graphs_container =
|
||||
std::make_unique<proto::ImageGeneratorGraphOptions>();
|
||||
options_proto_for_condition_image_graphs_container->CopyFrom(
|
||||
*options_proto_and_condition_index.options_proto);
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
auto image_generator,
|
||||
(core::VisionTaskApiFactory::Create<ImageGenerator,
|
||||
ImageGeneratorGraphOptionsProto>(
|
||||
CreateImageGeneratorGraphConfig(
|
||||
std::move(options_proto_and_condition_index.options_proto),
|
||||
use_condition_image),
|
||||
std::make_unique<tasks::core::MediaPipeBuiltinOpResolver>(),
|
||||
core::RunningMode::IMAGE,
|
||||
/*result_callback=*/nullptr)));
|
||||
image_generator->use_condition_image_ = use_condition_image;
|
||||
if (use_condition_image) {
|
||||
image_generator->condition_type_index_ =
|
||||
std::move(options_proto_and_condition_index.condition_type_index);
|
||||
ASSIGN_OR_RETURN(
|
||||
image_generator->condition_image_graphs_container_task_runner_,
|
||||
tasks::core::TaskRunner::Create(
|
||||
CreateConditionedImageGraphContainerConfig(
|
||||
std::move(options_proto_for_condition_image_graphs_container)),
|
||||
absl::make_unique<tasks::core::MediaPipeBuiltinOpResolver>()));
|
||||
}
|
||||
image_generator->init_timestamp_ = absl::Now();
|
||||
return image_generator;
|
||||
}
|
||||
|
||||
absl::StatusOr<Image> ImageGenerator::CreateConditionImage(
|
||||
Image source_condition_image,
|
||||
ConditionOptions::ConditionType condition_type) {
|
||||
if (condition_type_index_->find(condition_type) ==
|
||||
condition_type_index_->end()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"The condition type is not created during initialization.");
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_packets,
|
||||
condition_image_graphs_container_task_runner_->Process({
|
||||
{std::string(kSourceConditionImageName),
|
||||
MakePacket<Image>(std::move(source_condition_image))},
|
||||
{std::string(kSelectName),
|
||||
MakePacket<int>(condition_type_index_->at(condition_type))},
|
||||
}));
|
||||
return output_packets.at(std::string(kConditionImageName)).Get<Image>();
|
||||
}
|
||||
|
||||
absl::StatusOr<ImageGeneratorResult> ImageGenerator::Generate(
|
||||
const std::string& prompt, int iterations, int seed) {
|
||||
if (use_condition_image_) {
|
||||
return absl::InvalidArgumentError(
|
||||
"ImageGenerator is created to use with conditioned image.");
|
||||
}
|
||||
return RunIterations(prompt, iterations, seed, std::nullopt);
|
||||
}
|
||||
|
||||
absl::StatusOr<ImageGeneratorResult> ImageGenerator::Generate(
|
||||
const std::string& prompt, Image condition_image,
|
||||
ConditionOptions::ConditionType condition_type, int iterations, int seed) {
|
||||
if (!use_condition_image_) {
|
||||
return absl::InvalidArgumentError(
|
||||
"ImageGenerator is created to use without conditioned image.");
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto plugin_model_image,
|
||||
CreateConditionImage(condition_image, condition_type));
|
||||
return RunIterations(
|
||||
prompt, iterations, seed,
|
||||
ConditionInputs{plugin_model_image,
|
||||
condition_type_index_->at(condition_type)});
|
||||
}
|
||||
|
||||
absl::StatusOr<ImageGeneratorResult> ImageGenerator::RunIterations(
|
||||
const std::string& prompt, int steps, int rand_seed,
|
||||
std::optional<ConditionInputs> condition_inputs) {
|
||||
tasks::core::PacketMap output_packets;
|
||||
ImageGeneratorResult result;
|
||||
auto timestamp = (absl::Now() - init_timestamp_) / absl::Milliseconds(1);
|
||||
for (int i = 0; i < steps; ++i) {
|
||||
tasks::core::PacketMap input_packets;
|
||||
if (i == 0 && condition_inputs.has_value()) {
|
||||
input_packets[std::string(kConditionImageName)] =
|
||||
MakePacket<Image>(condition_inputs->condition_image)
|
||||
.At(Timestamp(timestamp));
|
||||
input_packets[std::string(kSelectName)] =
|
||||
MakePacket<int>(condition_inputs->select).At(Timestamp(timestamp));
|
||||
}
|
||||
input_packets[std::string(kStepsName)] =
|
||||
MakePacket<int>(steps).At(Timestamp(timestamp));
|
||||
input_packets[std::string(kIterationName)] =
|
||||
MakePacket<int>(i).At(Timestamp(timestamp));
|
||||
input_packets[std::string(kPromptName)] =
|
||||
MakePacket<std::string>(prompt).At(Timestamp(timestamp));
|
||||
input_packets[std::string(kRandSeedName)] =
|
||||
MakePacket<int>(rand_seed).At(Timestamp(timestamp));
|
||||
ASSIGN_OR_RETURN(output_packets, ProcessImageData(input_packets));
|
||||
timestamp += 1;
|
||||
}
|
||||
result.generated_image =
|
||||
output_packets.at(std::string(kImageOutName)).Get<Image>();
|
||||
if (condition_inputs.has_value()) {
|
||||
result.condition_image = condition_inputs->condition_image;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace image_generator
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
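The task API implemented above can be driven end to end roughly as follows. This is a hedged sketch only: the model directory, prompt, and step count are illustrative placeholders, not values taken from this change.

#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/tasks/cc/vision/image_generator/image_generator.h"

namespace {
using ::mediapipe::tasks::vision::image_generator::ImageGenerator;
using ::mediapipe::tasks::vision::image_generator::ImageGeneratorOptions;

absl::Status RunTextToImage() {
  auto options = std::make_unique<ImageGeneratorOptions>();
  // Placeholder path to the text-to-image model weights directory.
  options->text2image_model_directory = "/path/to/text2image/bins/";
  // No ConditionOptions passed: plain text-to-image generation.
  auto generator_or = ImageGenerator::Create(std::move(options));
  if (!generator_or.ok()) return generator_or.status();
  auto generator = std::move(*generator_or);
  auto result_or =
      generator->Generate("a photo of a cat", /*iterations=*/20, /*seed=*/0);
  if (!result_or.ok()) return result_or.status();
  // result_or->generated_image holds the decoded mediapipe::Image.
  return absl::OkStatus();
}
}  // namespace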
157
mediapipe/tasks/cc/vision/image_generator/image_generator.h
Normal file
|
@ -0,0 +1,157 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_generator/image_generator_result.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_generator {
|
||||
|
||||
// Options for drawing the face landmarks image.
|
||||
struct FaceConditionOptions {
|
||||
// The base options for plugin model.
|
||||
tasks::core::BaseOptions base_options;
|
||||
|
||||
// Face landmarker options used to detect face landmarks in the condition
|
||||
// image.
|
||||
face_landmarker::FaceLandmarkerOptions face_landmarker_options;
|
||||
};
|
||||
|
||||
// Options for detecting edges in the condition image.
|
||||
struct EdgeConditionOptions {
|
||||
// The base options for plugin model.
|
||||
tasks::core::BaseOptions base_options;
|
||||
|
||||
// These parameters are used to configure the Canny edge algorithm of OpenCV.
|
||||
// See more details:
|
||||
// https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de
|
||||
|
||||
// First threshold for the hysteresis procedure.
|
||||
float threshold_1 = 100;
|
||||
|
||||
// Second threshold for the hysteresis procedure.
|
||||
float threshold_2 = 200;
|
||||
|
||||
// Aperture size for the Sobel operator. Typical range is 3~7.
|
||||
int aperture_size = 3;
|
||||
|
||||
// A flag, indicating whether a more accurate L2 norm should be used to
|
||||
// calculate the image gradient magnitude ( L2gradient=true ), or whether
|
||||
// the default L1 norm is enough ( L2gradient=false ).
|
||||
bool l2_gradient = false;
|
||||
};
|
||||
|
||||
// Options for detecting depth in the condition image.
|
||||
struct DepthConditionOptions {
|
||||
// The base options for plugin model.
|
||||
tasks::core::BaseOptions base_options;
|
||||
|
||||
// Image segmenter options used to detect depth in the condition image.
|
||||
image_segmenter::ImageSegmenterOptions image_segmenter_options;
|
||||
};
|
||||
|
||||
struct ConditionOptions {
|
||||
enum ConditionType { FACE, EDGE, DEPTH };
|
||||
std::optional<FaceConditionOptions> face_condition_options;
|
||||
std::optional<EdgeConditionOptions> edge_condition_options;
|
||||
std::optional<DepthConditionOptions> depth_condition_options;
|
||||
};
|
||||
|
||||
// Note: The API is experimental and subject to change.
|
||||
// The options for configuring a mediapipe image generator task.
|
||||
struct ImageGeneratorOptions {
|
||||
// The text-to-image model directory storing the model weights.
|
||||
std::string text2image_model_directory;
|
||||
|
||||
// The path to LoRA weights file.
|
||||
std::optional<std::string> lora_weights_file_path;
|
||||
};
|
||||
|
||||
class ImageGenerator : tasks::vision::core::BaseVisionTaskApi {
|
||||
public:
|
||||
using BaseVisionTaskApi::BaseVisionTaskApi;
|
||||
|
||||
// Creates an ImageGenerator from the provided options.
|
||||
// image_generator_options: options to create the image generator.
|
||||
// condition_options: optional options if plugin models are used to generate
|
||||
// an image based on the condition image.
|
||||
static absl::StatusOr<std::unique_ptr<ImageGenerator>> Create(
|
||||
std::unique_ptr<ImageGeneratorOptions> image_generator_options,
|
||||
std::unique_ptr<ConditionOptions> condition_options = nullptr);
|
||||
|
||||
// Creates the condition image of the specified condition type from the source
// condition image. Face landmarks, depth image and edge image are currently
// supported as condition image types.
|
||||
absl::StatusOr<Image> CreateConditionImage(
|
||||
Image source_condition_image,
|
||||
ConditionOptions::ConditionType condition_type);
|
||||
|
||||
// Generates an image for the given number of iterations and random seed.
// Only valid when the ImageGenerator is created without condition options.
|
||||
absl::StatusOr<ImageGeneratorResult> Generate(const std::string& prompt,
|
||||
int iterations, int seed = 0);
|
||||
|
||||
// Generates an image based on the condition image, for the given number of
// iterations and random seed.
|
||||
// A detailed introduction to the condition image:
|
||||
// https://ai.googleblog.com/2023/06/on-device-diffusion-plugins-for.html
|
||||
absl::StatusOr<ImageGeneratorResult> Generate(
|
||||
const std::string& prompt, Image condition_image,
|
||||
ConditionOptions::ConditionType condition_type, int iterations,
|
||||
int seed = 0);
|
||||
|
||||
private:
|
||||
struct ConditionInputs {
|
||||
Image condition_image;
|
||||
int select;
|
||||
};
|
||||
|
||||
bool use_condition_image_ = false;
|
||||
|
||||
absl::Time init_timestamp_;
|
||||
|
||||
std::unique_ptr<tasks::core::TaskRunner>
|
||||
condition_image_graphs_container_task_runner_;
|
||||
|
||||
std::unique_ptr<std::map<ConditionOptions::ConditionType, int>>
|
||||
condition_type_index_;
|
||||
|
||||
absl::StatusOr<ImageGeneratorResult> RunIterations(
|
||||
const std::string& prompt, int steps, int rand_seed,
|
||||
std::optional<ConditionInputs> condition_inputs);
|
||||
};
|
||||
|
||||
} // namespace image_generator
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_H_
|
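For the conditioned path declared in this header, a sketch of wiring ConditionOptions with an edge plugin might look like the following. It is a hedged illustration only: the plugin model path, prompt, and the assumption that the base options take a model_asset_path are placeholders for whatever assets the caller actually has.

#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/image_generator/image_generator.h"

namespace {
using ::mediapipe::tasks::vision::image_generator::ConditionOptions;
using ::mediapipe::tasks::vision::image_generator::EdgeConditionOptions;
using ::mediapipe::tasks::vision::image_generator::ImageGenerator;
using ::mediapipe::tasks::vision::image_generator::ImageGeneratorOptions;

absl::Status RunEdgeConditionedGeneration(mediapipe::Image condition_source) {
  auto options = std::make_unique<ImageGeneratorOptions>();
  // Placeholder path to the text-to-image model weights directory.
  options->text2image_model_directory = "/path/to/text2image/bins/";

  auto condition_options = std::make_unique<ConditionOptions>();
  EdgeConditionOptions edge_options;
  // Placeholder path to a hypothetical edge control plugin model.
  edge_options.base_options.model_asset_path = "/path/to/edge_plugin.tflite";
  condition_options->edge_condition_options = edge_options;

  auto generator_or =
      ImageGenerator::Create(std::move(options), std::move(condition_options));
  if (!generator_or.ok()) return generator_or.status();
  auto generator = std::move(*generator_or);
  auto result_or = generator->Generate(
      "a watercolor landscape", std::move(condition_source),
      ConditionOptions::EDGE, /*iterations=*/20, /*seed=*/0);
  if (!result_or.ok()) return result_or.status();
  // result_or->condition_image holds the edge image derived from the input.
  return absl::OkStatus();
}
}  // namespace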
|
@ -0,0 +1,361 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/framework/tool/switch_container.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options.pb.h"
|
||||
#include "mediapipe/util/graph_builder_utils.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_generator {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
|
||||
constexpr int kPluginsOutputSize = 512;
|
||||
constexpr absl::string_view kTensorsTag = "TENSORS";
|
||||
constexpr absl::string_view kImageTag = "IMAGE";
|
||||
constexpr absl::string_view kImageCpuTag = "IMAGE_CPU";
|
||||
constexpr absl::string_view kStepsTag = "STEPS";
|
||||
constexpr absl::string_view kIterationTag = "ITERATION";
|
||||
constexpr absl::string_view kPromptTag = "PROMPT";
|
||||
constexpr absl::string_view kRandSeedTag = "RAND_SEED";
|
||||
constexpr absl::string_view kPluginTensorsTag = "PLUGIN_TENSORS";
|
||||
constexpr absl::string_view kConditionImageTag = "CONDITION_IMAGE";
|
||||
constexpr absl::string_view kSelectTag = "SELECT";
|
||||
constexpr absl::string_view kMetadataFilename = "metadata";
|
||||
constexpr absl::string_view kLoraRankStr = "lora_rank";
|
||||
|
||||
struct ImageGeneratorInputs {
|
||||
Source<std::string> prompt;
|
||||
Source<int> steps;
|
||||
Source<int> iteration;
|
||||
Source<int> rand_seed;
|
||||
std::optional<Source<Image>> condition_image;
|
||||
std::optional<Source<int>> select_condition_type;
|
||||
};
|
||||
|
||||
struct ImageGeneratorOutputs {
|
||||
Source<Image> generated_image;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// A container graph containing several ConditionedImageGraph nodes from which
// to choose the specified condition type.
|
||||
// Inputs:
|
||||
// IMAGE - Image
|
||||
// The source condition image, used to generate the condition image.
|
||||
// SELECT - int
|
||||
// The index of the selected conditioned image graph.
|
||||
// Outputs:
|
||||
// CONDITION_IMAGE - Image
|
||||
// The condition image created from the specified condition type.
|
||||
class ConditionedImageGraphContainer : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
auto& graph_options =
|
||||
*sc->MutableOptions<proto::ImageGeneratorGraphOptions>();
|
||||
auto source_condition_image = graph.In(kImageTag).Cast<Image>();
|
||||
auto select_condition_type = graph.In(kSelectTag).Cast<int>();
|
||||
auto& switch_container = graph.AddNode("SwitchContainer");
|
||||
auto& switch_options =
|
||||
switch_container.GetOptions<mediapipe::SwitchContainerOptions>();
|
||||
for (auto& control_plugin_graph_options :
|
||||
*graph_options.mutable_control_plugin_graphs_options()) {
|
||||
auto& node = *switch_options.add_contained_node();
|
||||
node.set_calculator(
|
||||
"mediapipe.tasks.vision.image_generator.ConditionedImageGraph");
|
||||
node.mutable_node_options()->Add()->PackFrom(
|
||||
control_plugin_graph_options.conditioned_image_graph_options());
|
||||
}
|
||||
source_condition_image >> switch_container.In(kImageTag);
|
||||
select_condition_type >> switch_container.In(kSelectTag);
|
||||
auto condition_image = switch_container.Out(kImageTag).Cast<Image>();
|
||||
condition_image >> graph.Out(kConditionImageTag);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::image_generator::ConditionedImageGraphContainer); // NOLINT
|
||||
// clang-format on
|
||||
|
||||
// A helper graph to convert the condition image to Tensors using the control
// plugin model.
|
||||
// Inputs:
|
||||
// CONDITION_IMAGE - Image
|
||||
// The condition image input to the control plugin model.
|
||||
// Outputs:
|
||||
// PLUGIN_TENSORS - std::vector<Tensor>
|
||||
// The output tensors from the control plugin model. The tensors are used as
|
||||
// inputs to the image generation model.
|
||||
class ControlPluginGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
auto& graph_options =
|
||||
*sc->MutableOptions<proto::ControlPluginGraphOptions>();
|
||||
|
||||
auto condition_image = graph.In(kConditionImageTag).Cast<Image>();
|
||||
|
||||
// Convert Image to ImageFrame.
|
||||
auto& from_image = graph.AddNode("FromImageCalculator");
|
||||
condition_image >> from_image.In(kImageTag);
|
||||
auto image_frame = from_image.Out(kImageCpuTag);
|
||||
|
||||
// Convert ImageFrame to Tensor.
|
||||
auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator");
|
||||
auto& image_to_tensor_options =
|
||||
image_to_tensor.GetOptions<mediapipe::ImageToTensorCalculatorOptions>();
|
||||
image_to_tensor_options.set_output_tensor_width(kPluginsOutputSize);
|
||||
image_to_tensor_options.set_output_tensor_height(kPluginsOutputSize);
|
||||
image_to_tensor_options.mutable_output_tensor_float_range()->set_min(-1);
|
||||
image_to_tensor_options.mutable_output_tensor_float_range()->set_max(1);
|
||||
image_to_tensor_options.set_keep_aspect_ratio(true);
|
||||
image_frame >> image_to_tensor.In(kImageTag);
|
||||
|
||||
// Create the plugin model resource.
|
||||
ASSIGN_OR_RETURN(
|
||||
const core::ModelResources* plugin_model_resources,
|
||||
CreateModelResources(
|
||||
sc,
|
||||
std::make_unique<tasks::core::proto::ExternalFile>(
|
||||
*graph_options.mutable_base_options()->mutable_model_asset())));
|
||||
|
||||
// Add control plugin model inference.
|
||||
auto& plugins_inference =
|
||||
AddInference(*plugin_model_resources,
|
||||
graph_options.base_options().acceleration(), graph);
|
||||
image_to_tensor.Out(kTensorsTag) >> plugins_inference.In(kTensorsTag);
|
||||
// The plugins model is not runnable on OpenGL. Error message:
|
||||
// TfLiteGpuDelegate Prepare: Batch size mismatch, expected 1 but got 64
|
||||
// Node number 67 (TfLiteGpuDelegate) failed to prepare.
|
||||
plugins_inference.GetOptions<mediapipe::InferenceCalculatorOptions>()
|
||||
.mutable_delegate()
|
||||
->mutable_xnnpack();
|
||||
plugins_inference.Out(kTensorsTag).Cast<std::vector<Tensor>>() >>
|
||||
graph.Out(kPluginTensorsTag);
|
||||
return graph.GetConfig();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::image_generator::ControlPluginGraph);
|
||||
|
||||
// A "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph" performs image
|
||||
// generation from a text prompt and an optional condition image.
|
||||
//
|
||||
// Inputs:
|
||||
// PROMPT - std::string
|
||||
// The prompt describing the image to be generated.
|
||||
// STEPS - int
|
||||
// The total steps to generate the image.
|
||||
// ITERATION - int
|
||||
// The current iteration in the generating steps. Must be less than STEPS.
|
||||
// RAND_SEED - int
|
||||
// The random seed input to the image generation model.
|
||||
// CONDITION_IMAGE - Image
|
||||
// The condition image used as guidance for the image generation. Only
// valid if control plugin graph options are set in the graph options.
|
||||
// SELECT - int
|
||||
// The index of the selected control plugin graph.
|
||||
//
|
||||
// Outputs:
|
||||
// IMAGE - Image
|
||||
// The generated image.
|
||||
// STEPS - int @optional
|
||||
// The total steps to generate the image. The same as STEPS input.
|
||||
// ITERATION - int @optional
|
||||
// The current iteration in the generating steps. The same as ITERATION
|
||||
// input.
|
||||
class ImageGeneratorGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
auto* subgraph_options =
|
||||
sc->MutableOptions<proto::ImageGeneratorGraphOptions>();
|
||||
std::optional<const core::ModelAssetBundleResources*> lora_resources;
|
||||
// Create LoRA weights asset bundle resources.
|
||||
if (subgraph_options->has_lora_weights_file()) {
|
||||
auto external_file = std::make_unique<tasks::core::proto::ExternalFile>();
|
||||
external_file->Swap(subgraph_options->mutable_lora_weights_file());
|
||||
ASSIGN_OR_RETURN(lora_resources, CreateModelAssetBundleResources(
|
||||
sc, std::move(external_file)));
|
||||
}
|
||||
std::optional<Source<Image>> condition_image;
|
||||
std::optional<Source<int>> select_condition_type;
|
||||
if (!subgraph_options->control_plugin_graphs_options().empty()) {
|
||||
condition_image = graph.In(kConditionImageTag).Cast<Image>();
|
||||
select_condition_type = graph.In(kSelectTag).Cast<int>();
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
auto outputs,
|
||||
BuildImageGeneratorGraph(
|
||||
*sc->MutableOptions<proto::ImageGeneratorGraphOptions>(),
|
||||
lora_resources,
|
||||
ImageGeneratorInputs{
|
||||
/*prompt=*/graph.In(kPromptTag).Cast<std::string>(),
|
||||
/*steps=*/graph.In(kStepsTag).Cast<int>(),
|
||||
/*iteration=*/graph.In(kIterationTag).Cast<int>(),
|
||||
/*rand_seed=*/graph.In(kRandSeedTag).Cast<int>(),
|
||||
/*condition_image*/ condition_image,
|
||||
/*select_condition_type*/ select_condition_type,
|
||||
},
|
||||
graph));
|
||||
outputs.generated_image >> graph.Out(kImageTag).Cast<Image>();
|
||||
|
||||
// Optional outputs to provide the current iteration.
|
||||
auto& pass_through = graph.AddNode("PassThroughCalculator");
|
||||
graph.In(kIterationTag) >> pass_through.In(0);
|
||||
graph.In(kStepsTag) >> pass_through.In(1);
|
||||
pass_through.Out(0) >> graph[Output<int>::Optional(kIterationTag)];
|
||||
pass_through.Out(1) >> graph[Output<int>::Optional(kStepsTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
absl::StatusOr<ImageGeneratorOutputs> BuildImageGeneratorGraph(
|
||||
proto::ImageGeneratorGraphOptions& subgraph_options,
|
||||
std::optional<const core::ModelAssetBundleResources*> lora_resources,
|
||||
ImageGeneratorInputs inputs, Graph& graph) {
|
||||
auto& stable_diff = graph.AddNode("StableDiffusionIterateCalculator");
|
||||
if (inputs.condition_image.has_value()) {
|
||||
// Add switch container for multiple control plugin graphs.
|
||||
auto& switch_container = graph.AddNode("SwitchContainer");
|
||||
auto& switch_options =
|
||||
switch_container.GetOptions<mediapipe::SwitchContainerOptions>();
|
||||
for (auto& control_plugin_graph_options :
|
||||
*subgraph_options.mutable_control_plugin_graphs_options()) {
|
||||
auto& node = *switch_options.add_contained_node();
|
||||
node.set_calculator(
|
||||
"mediapipe.tasks.vision.image_generator.ControlPluginGraph");
|
||||
node.mutable_node_options()->Add()->PackFrom(
|
||||
control_plugin_graph_options);
|
||||
}
|
||||
*inputs.condition_image >> switch_container.In(kConditionImageTag);
|
||||
*inputs.select_condition_type >> switch_container.In(kSelectTag);
|
||||
auto plugin_tensors = switch_container.Out(kPluginTensorsTag);
|
||||
|
||||
// Additional diffusion plugins calculator to pass tensors to diffusion
|
||||
// iterator.
|
||||
auto& plugins_output = graph.AddNode("DiffusionPluginsOutputCalculator");
|
||||
plugin_tensors >> plugins_output.In(kTensorsTag);
|
||||
inputs.steps >> plugins_output.In(kStepsTag);
|
||||
inputs.iteration >> plugins_output.In(kIterationTag);
|
||||
plugins_output.Out(kTensorsTag) >> stable_diff.In(kPluginTensorsTag);
|
||||
}
|
||||
|
||||
inputs.prompt >> stable_diff.In(kPromptTag);
|
||||
inputs.steps >> stable_diff.In(kStepsTag);
|
||||
inputs.iteration >> stable_diff.In(kIterationTag);
|
||||
inputs.rand_seed >> stable_diff.In(kRandSeedTag);
|
||||
mediapipe::StableDiffusionIterateCalculatorOptions& options =
|
||||
stable_diff
|
||||
.GetOptions<mediapipe::StableDiffusionIterateCalculatorOptions>();
|
||||
options.set_base_seed(0);
|
||||
options.set_output_image_height(kPluginsOutputSize);
|
||||
options.set_output_image_width(kPluginsOutputSize);
|
||||
options.set_file_folder(subgraph_options.text2image_model_directory());
|
||||
options.set_show_every_n_iteration(100);
|
||||
options.set_emit_empty_packet(true);
|
||||
if (lora_resources.has_value()) {
|
||||
auto& lora_layer_weights_mapping =
|
||||
*options.mutable_lora_weights_layer_mapping();
|
||||
for (const auto& file_path : (*lora_resources)->ListFiles()) {
|
||||
auto basename = file::Basename(file_path);
|
||||
ASSIGN_OR_RETURN(auto file_content,
|
||||
(*lora_resources)->GetFile(std::string(file_path)));
|
||||
if (file_path == kMetadataFilename) {
|
||||
MP_RETURN_IF_ERROR(
|
||||
ParseLoraMetadataAndConfigOptions(file_content, options));
|
||||
} else {
|
||||
lora_layer_weights_mapping[basename] =
|
||||
reinterpret_cast<uint64_t>(file_content.data());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto& to_image = graph.AddNode("ToImageCalculator");
|
||||
stable_diff.Out(kImageTag) >> to_image.In(kImageCpuTag);
|
||||
|
||||
return {{to_image.Out(kImageTag).Cast<Image>()}};
|
||||
}
|
||||
|
||||
private:
|
||||
absl::Status ParseLoraMetadataAndConfigOptions(
|
||||
absl::string_view contents,
|
||||
mediapipe::StableDiffusionIterateCalculatorOptions& options) {
|
||||
std::vector<absl::string_view> lines =
|
||||
absl::StrSplit(contents, '\n', absl::SkipEmpty());
|
||||
for (const auto& line : lines) {
|
||||
std::vector<absl::string_view> values = absl::StrSplit(line, ',');
|
||||
if (values[0] == kLoraRankStr) {
|
||||
int lora_rank;
|
||||
if (values.size() != 2 || !absl::SimpleAtoi(values[1], &lora_rank)) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Error parsing LoRA weights metadata. ", line));
|
||||
}
|
||||
options.set_lora_rank(lora_rank);
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::image_generator::ImageGeneratorGraph);
|
||||
|
||||
} // namespace image_generator
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
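The metadata file handled by ParseLoraMetadataAndConfigOptions above is, judging only from the parsing code, a newline-separated list of comma-separated key/value pairs. A minimal sketch of parsing one such hypothetical line (the contents string is illustrative, not a real asset):

#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

int main() {
  // Hypothetical contents of the "metadata" file in a LoRA asset bundle.
  const std::string contents = "lora_rank,8\n";
  for (absl::string_view line :
       absl::StrSplit(contents, '\n', absl::SkipEmpty())) {
    std::vector<absl::string_view> values = absl::StrSplit(line, ',');
    int lora_rank = 0;
    if (values.size() == 2 && values[0] == "lora_rank" &&
        absl::SimpleAtoi(values[1], &lora_rank)) {
      std::cout << "lora_rank=" << lora_rank << "\n";  // Prints lora_rank=8.
    }
  }
  return 0;
}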
|
@ -0,0 +1,41 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_
|
||||
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_generator {
|
||||
|
||||
// The result of ImageGenerator task.
|
||||
struct ImageGeneratorResult {
|
||||
// The generated image.
|
||||
Image generated_image;
|
||||
|
||||
// The condition_image used in the plugin model, only available if the
|
||||
// condition type is set in ImageGeneratorOptions.
|
||||
std::optional<Image> condition_image = std::nullopt;
|
||||
};
|
||||
|
||||
} // namespace image_generator
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_GENERATOR_IMAGE_GENERATOR_RESULT_H_
|
52
mediapipe/tasks/cc/vision/image_generator/proto/BUILD
Normal file
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
package(default_visibility = [
|
||||
"//mediapipe/tasks:internal",
|
||||
])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "conditioned_image_graph_options_proto",
|
||||
srcs = ["conditioned_image_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "control_plugin_graph_options_proto",
|
||||
srcs = ["control_plugin_graph_options.proto"],
|
||||
deps = [
|
||||
":conditioned_image_graph_options_proto",
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "image_generator_graph_options_proto",
|
||||
srcs = ["image_generator_graph_options.proto"],
|
||||
deps = [
|
||||
":control_plugin_graph_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_proto",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,66 @@
|
|||
|
||||
/* Copyright 2023 The MediaPipe Authors.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
syntax = "proto3";
|
||||
|
||||
package mediapipe.tasks.vision.image_generator.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto";
|
||||
import "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
|
||||
option java_outer_classname = "ConditionedImageGraphOptionsProto";
|
||||
|
||||
message ConditionedImageGraphOptions {
|
||||
// For conditioned image graph based on face landmarks.
|
||||
message FaceConditionTypeOptions {
|
||||
// Options for the face landmarker used in the face landmarks type graph.
|
||||
face_landmarker.proto.FaceLandmarkerGraphOptions
|
||||
face_landmarker_graph_options = 1;
|
||||
}
|
||||
|
||||
// For conditioned image graph based on edge detection.
|
||||
message EdgeConditionTypeOptions {
|
||||
// These parameters are used to configure the Canny edge algorithm of OpenCV.
|
||||
// See more details:
|
||||
// https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de
|
||||
|
||||
// First threshold for the hysteresis procedure.
|
||||
float threshold_1 = 1;
|
||||
|
||||
// Second threshold for the hysteresis procedure.
|
||||
float threshold_2 = 2;
|
||||
|
||||
// Aperture size for the Sobel operator. Typical range is 3~7.
|
||||
int32 aperture_size = 3;
|
||||
|
||||
// A flag, indicating whether a more accurate L2 norm should be used to
|
||||
// calculate the image gradient magnitude ( L2gradient=true ), or whether
|
||||
// the default L1 norm is enough ( L2gradient=false ).
|
||||
bool l2_gradient = 4;
|
||||
}
|
||||
|
||||
// For conditioned image graph based on depth map.
|
||||
message DepthConditionTypeOptions {
|
||||
// Options for the image segmenter used in the depth condition type graph.
|
||||
image_segmenter.proto.ImageSegmenterGraphOptions
|
||||
image_segmenter_graph_options = 1;
|
||||
}
|
||||
|
||||
// The options for configuring the conditioned image graph.
|
||||
oneof condition_type_options {
|
||||
FaceConditionTypeOptions face_condition_type_options = 2;
|
||||
EdgeConditionTypeOptions edge_condition_type_options = 3;
|
||||
DepthConditionTypeOptions depth_condition_type_options = 4;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package mediapipe.tasks.vision.image_generator.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
import "mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
|
||||
option java_outer_classname = "ControlPluginGraphOptionsProto";
|
||||
|
||||
message ControlPluginGraphOptions {
|
||||
// The base options for the control plugin model.
|
||||
core.proto.BaseOptions base_options = 1;
|
||||
|
||||
// The options for the ConditionedImageGraph used to generate the control
// plugin model input image.
|
||||
proto.ConditionedImageGraphOptions conditioned_image_graph_options = 2;
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/* Copyright 2023 The MediaPipe Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package mediapipe.tasks.vision.image_generator.proto;
|
||||
|
||||
import "mediapipe/tasks/cc/core/proto/external_file.proto";
|
||||
import "mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.vision.imagegenerator.proto";
|
||||
option java_outer_classname = "ImageGeneratorGraphOptionsProto";
|
||||
|
||||
message ImageGeneratorGraphOptions {
|
||||
// The directory containing the model weights of the text-to-image model.
|
||||
string text2image_model_directory = 1;
|
||||
|
||||
// An optional LoRA weights file. If set, the diffusion model will be created
|
||||
// with LoRA weights.
|
||||
core.proto.ExternalFile lora_weights_file = 2;
|
||||
|
||||
repeated proto.ControlPluginGraphOptions control_plugin_graphs_options = 3;
|
||||
}
|
|
@ -59,6 +59,22 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite",
|
||||
]
|
||||
|
||||
_VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS = [
|
||||
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:mesh_3d_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
|
||||
]
|
||||
|
||||
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite",
|
||||
|
@@ -249,6 +265,39 @@ EOF
        native_library = native_library,
    )

def mediapipe_tasks_vision_image_generator_aar(name, srcs, native_library):
    """Builds mediapipe tasks vision image generator AAR.

    Args:
      name: The bazel target name.
      srcs: MediaPipe Vision Tasks' source files.
      native_library: The native library that contains the image generator task's graph and calculators.
    """

    native.genrule(
        name = name + "tasks_manifest_generator",
        outs = ["AndroidManifest.xml"],
        cmd = """
cat > $(OUTS) <<EOF
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.vision.imagegenerator">
  <uses-sdk
      android:minSdkVersion="24"
      android:targetSdkVersion="30" />
</manifest>
EOF
""",
    )

    _mediapipe_tasks_aar(
        name = name,
        srcs = srcs,
        manifest = "AndroidManifest.xml",
        java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_IMAGE_GENERATOR_JAVA_PROTO_LITE_TARGETS,
        native_library = native_library,
    )

def mediapipe_tasks_text_aar(name, srcs, native_library):
    """Builds mediapipe tasks text AAR.
@@ -344,6 +393,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l
            "//third_party:androidx_annotation",
            "//third_party:autovalue",
            "@maven//:com_google_guava_guava",
            "@com_google_protobuf//:protobuf_javalite",
        ] + select({
            "//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
            "//mediapipe/framework/port:disable_opencv": [],
@@ -413,6 +413,9 @@ load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl"

mediapipe_tasks_vision_aar(
    name = "tasks_vision",
    srcs = glob(["**/*.java"]),
    srcs = glob(
        ["**/*.java"],
        exclude = ["imagegenerator/**"],
    ),
    native_library = ":libmediapipe_tasks_vision_jni_lib",
)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.vision.imagegenerator">

  <uses-sdk android:minSdkVersion="24"
      android:targetSdkVersion="30" />

</manifest>
@@ -0,0 +1,84 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])

package(default_visibility = ["//visibility:public"])

# The native library of MediaPipe vision image generator tasks.
cc_binary(
    name = "libmediapipe_tasks_vision_image_generator_jni.so",
    linkopts = [
        "-Wl,--no-undefined",
        "-Wl,--version-script,$(location //mediapipe/tasks/java:version_script.lds)",
    ],
    linkshared = 1,
    linkstatic = 1,
    deps = [
        "//mediapipe/calculators/core:flow_limiter_calculator",
        "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
        "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
        "//mediapipe/tasks/cc/vision/image_generator:image_generator_graph",
        "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
        "//mediapipe/tasks/java:version_script.lds",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
    ],
)

cc_library(
    name = "libmediapipe_tasks_vision_image_generator_jni_lib",
    srcs = [":libmediapipe_tasks_vision_image_generator_jni.so"],
    alwayslink = 1,
)

android_library(
    name = "imagegenerator",
    srcs = [
        "ImageGenerator.java",
        "ImageGeneratorResult.java",
    ],
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    manifest = "AndroidManifest.xml",
    deps = [
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/java/com/google/mediapipe/framework/image",
        "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/image_generator/proto:conditioned_image_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/image_generator/proto:control_plugin_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/image_generator/proto:image_generator_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:core_java",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:facelandmarker",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:imagesegmenter",
        "//third_party:any_java_proto",
        "//third_party:autovalue",
        "//third_party/java/protobuf:protobuf_lite",
        "@maven//:androidx_annotation_annotation",
        "@maven//:com_google_guava_guava",
    ],
)

load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_image_generator_aar")

mediapipe_tasks_vision_image_generator_aar(
    name = "tasks_vision_image_generator",
    srcs = glob(["**/*.java"]),
    native_library = ":libmediapipe_tasks_vision_image_generator_jni_lib",
)
@@ -0,0 +1,660 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.imagegenerator;

import android.content.Context;
import android.graphics.Bitmap;
import android.util.Log;
import androidx.annotation.Nullable;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.ExternalFileProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.facelandmarker.FaceLandmarker.FaceLandmarkerOptions;
import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarkerGraphOptionsProto.FaceLandmarkerGraphOptions;
import com.google.mediapipe.tasks.vision.imagegenerator.proto.ConditionedImageGraphOptionsProto.ConditionedImageGraphOptions;
import com.google.mediapipe.tasks.vision.imagegenerator.proto.ControlPluginGraphOptionsProto;
import com.google.mediapipe.tasks.vision.imagegenerator.proto.ImageGeneratorGraphOptionsProto;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions;
import com.google.protobuf.Any;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/** Performs image generation from a text prompt. */
public final class ImageGenerator extends BaseVisionTaskApi {

  private static final String STEPS_STREAM_NAME = "steps";
  private static final String ITERATION_STREAM_NAME = "iteration";
  private static final String PROMPT_STREAM_NAME = "prompt";
  private static final String RAND_SEED_STREAM_NAME = "rand_seed";
  private static final String SOURCE_CONDITION_IMAGE_STREAM_NAME = "source_condition_image";
  private static final String CONDITION_IMAGE_STREAM_NAME = "condition_image";
  private static final String SELECT_STREAM_NAME = "select";
  private static final int GENERATED_IMAGE_OUT_STREAM_INDEX = 0;
  private static final int STEPS_OUT_STREAM_INDEX = 1;
  private static final int ITERATION_OUT_STREAM_INDEX = 2;
  private static final String TASK_GRAPH_NAME =
      "mediapipe.tasks.vision.image_generator.ImageGeneratorGraph";
  private static final String CONDITION_IMAGE_GRAPHS_CONTAINER_NAME =
      "mediapipe.tasks.vision.image_generator.ConditionedImageGraphContainer";
  private static final String TAG = "ImageGenerator";
  private TaskRunner conditionImageGraphsContainerTaskRunner;
  private Map<ConditionOptions.ConditionType, Integer> conditionTypeIndex;
  private boolean useConditionImage = false;

  /**
   * Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions}.
   *
   * @param context an Android {@link Context}.
   * @param generatorOptions an {@link ImageGeneratorOptions} instance.
   * @throws MediaPipeException if there is an error during {@link ImageGenerator} creation.
   */
  public static ImageGenerator createFromOptions(
      Context context, ImageGeneratorOptions generatorOptions) {
    return createFromOptions(context, generatorOptions, null);
  }

  /**
   * Creates an {@link ImageGenerator} instance from an {@link ImageGeneratorOptions} and a {@link
   * ConditionOptions}, when plugin models are used to generate an image based on a condition
   * image.
   *
   * @param context an Android {@link Context}.
   * @param generatorOptions an {@link ImageGeneratorOptions} instance.
   * @param conditionOptions a {@link ConditionOptions} instance.
   * @throws MediaPipeException if there is an error during {@link ImageGenerator} creation.
   */
  public static ImageGenerator createFromOptions(
      Context context,
      ImageGeneratorOptions generatorOptions,
      @Nullable ConditionOptions conditionOptions) {
    List<String> inputStreams = new ArrayList<>();
    inputStreams.addAll(
        Arrays.asList(
            "STEPS:" + STEPS_STREAM_NAME,
            "ITERATION:" + ITERATION_STREAM_NAME,
            "PROMPT:" + PROMPT_STREAM_NAME,
            "RAND_SEED:" + RAND_SEED_STREAM_NAME));
    final boolean useConditionImage = conditionOptions != null;
    if (useConditionImage) {
      inputStreams.add("SELECT:" + SELECT_STREAM_NAME);
      inputStreams.add("CONDITION_IMAGE:" + CONDITION_IMAGE_STREAM_NAME);
      generatorOptions.conditionOptions = Optional.of(conditionOptions);
    }
    List<String> outputStreams =
        Arrays.asList("IMAGE:image_out", "STEPS:steps_out", "ITERATION:iteration_out");

    OutputHandler<ImageGeneratorResult, Void> handler = new OutputHandler<>();
    handler.setOutputPacketConverter(
        new OutputHandler.OutputPacketConverter<ImageGeneratorResult, Void>() {
          @Override
          @Nullable
          public ImageGeneratorResult convertToTaskResult(List<Packet> packets) {
            int iteration = PacketGetter.getInt32(packets.get(ITERATION_OUT_STREAM_INDEX));
            int steps = PacketGetter.getInt32(packets.get(STEPS_OUT_STREAM_INDEX));
            Log.i("ImageGenerator", "Iteration: " + iteration + ", Steps: " + steps);
            if (iteration != steps - 1) {
              return null;
            }
            Log.i("ImageGenerator", "processing generated image");
            Packet packet = packets.get(GENERATED_IMAGE_OUT_STREAM_INDEX);
            Bitmap generatedBitmap = AndroidPacketGetter.getBitmapFromRgb(packet);
            BitmapImageBuilder bitmapImageBuilder = new BitmapImageBuilder(generatedBitmap);
            return ImageGeneratorResult.create(
                bitmapImageBuilder.build(), packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
          }

          @Override
          public Void convertToTaskInput(List<Packet> packets) {
            return null;
          }
        });
    handler.setHandleTimestampBoundChanges(true);
    if (generatorOptions.resultListener().isPresent()) {
      ResultListener<ImageGeneratorResult, Void> resultListener =
          new ResultListener<ImageGeneratorResult, Void>() {
            @Override
            public void run(ImageGeneratorResult imageGeneratorResult, Void input) {
              generatorOptions.resultListener().get().run(imageGeneratorResult);
            }
          };
      handler.setResultListener(resultListener);
    }
    generatorOptions.errorListener().ifPresent(handler::setErrorListener);
    TaskRunner runner =
        TaskRunner.create(
            context,
            TaskInfo.<ImageGeneratorOptions>builder()
                .setTaskName(ImageGenerator.class.getSimpleName())
                .setTaskRunningModeName(RunningMode.IMAGE.name())
                .setTaskGraphName(TASK_GRAPH_NAME)
                .setInputStreams(inputStreams)
                .setOutputStreams(outputStreams)
                .setTaskOptions(generatorOptions)
                .setEnableFlowLimiting(false)
                .build(),
            handler);
    ImageGenerator imageGenerator = new ImageGenerator(runner);
    if (useConditionImage) {
      imageGenerator.useConditionImage = true;
      inputStreams =
          Arrays.asList(
              "IMAGE:" + SOURCE_CONDITION_IMAGE_STREAM_NAME, "SELECT:" + SELECT_STREAM_NAME);
      outputStreams = Arrays.asList("CONDITION_IMAGE:" + CONDITION_IMAGE_STREAM_NAME);
      OutputHandler<ConditionImageResult, Void> conditionImageHandler = new OutputHandler<>();
      conditionImageHandler.setOutputPacketConverter(
          new OutputHandler.OutputPacketConverter<ConditionImageResult, Void>() {
            @Override
            public ConditionImageResult convertToTaskResult(List<Packet> packets) {
              Packet packet = packets.get(0);
              return new AutoValue_ImageGenerator_ConditionImageResult(
                  new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(packet)).build(),
                  packet.getTimestamp() / MICROSECONDS_PER_MILLISECOND);
            }

            @Override
            public Void convertToTaskInput(List<Packet> packets) {
              return null;
            }
          });
      conditionImageHandler.setHandleTimestampBoundChanges(true);
      imageGenerator.conditionImageGraphsContainerTaskRunner =
          TaskRunner.create(
              context,
              TaskInfo.<ImageGeneratorOptions>builder()
                  .setTaskName(ImageGenerator.class.getSimpleName())
                  .setTaskRunningModeName(RunningMode.IMAGE.name())
                  .setTaskGraphName(CONDITION_IMAGE_GRAPHS_CONTAINER_NAME)
                  .setInputStreams(inputStreams)
                  .setOutputStreams(outputStreams)
                  .setTaskOptions(generatorOptions)
                  .setEnableFlowLimiting(false)
                  .build(),
              conditionImageHandler);
      imageGenerator.conditionTypeIndex = new HashMap<>();
      if (conditionOptions.faceConditionOptions().isPresent()) {
        imageGenerator.conditionTypeIndex.put(
            ConditionOptions.ConditionType.FACE, imageGenerator.conditionTypeIndex.size());
      }
      if (conditionOptions.edgeConditionOptions().isPresent()) {
        imageGenerator.conditionTypeIndex.put(
            ConditionOptions.ConditionType.EDGE, imageGenerator.conditionTypeIndex.size());
      }
      if (conditionOptions.depthConditionOptions().isPresent()) {
        imageGenerator.conditionTypeIndex.put(
            ConditionOptions.ConditionType.DEPTH, imageGenerator.conditionTypeIndex.size());
      }
    }
    return imageGenerator;
  }

  private ImageGenerator(TaskRunner taskRunner) {
    super(taskRunner, RunningMode.IMAGE, "", "");
  }

  /**
   * Generates an image over the given number of iterations with the given random seed. Only valid
   * when the ImageGenerator is created without condition options.
   *
   * @param prompt The text prompt describing the image to be generated.
   * @param iterations The total iterations to generate the image.
   * @param seed The random seed used during image generation.
   */
  public ImageGeneratorResult generate(String prompt, int iterations, int seed) {
    return runIterations(prompt, iterations, seed, null, 0);
  }

  /**
   * Generates an image based on the source image, over the given number of iterations with the
   * given random seed. Only valid when the ImageGenerator is created with condition options.
   *
   * @param prompt The text prompt describing the image to be generated.
   * @param sourceConditionImage The source image used to create the condition image, which is used
   *     as a guidance for the image generation.
   * @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of
   *     condition image.
   * @param iterations The total iterations to generate the image.
   * @param seed The random seed used during image generation.
   */
  public ImageGeneratorResult generate(
      String prompt,
      MPImage sourceConditionImage,
      ConditionOptions.ConditionType conditionType,
      int iterations,
      int seed) {
    return runIterations(
        prompt,
        iterations,
        seed,
        createConditionImage(sourceConditionImage, conditionType),
        conditionTypeIndex.get(conditionType));
  }

  /**
   * Creates the condition image of the specified condition type from the source image. Currently
   * supports face landmarks, depth image, and edge image as the condition image.
   *
   * @param sourceConditionImage The source image used to create the condition image.
   * @param conditionType The {@link ConditionOptions.ConditionType} specifying the type of
   *     condition image.
   */
  public MPImage createConditionImage(
      MPImage sourceConditionImage, ConditionOptions.ConditionType conditionType) {
    if (!conditionTypeIndex.containsKey(conditionType)) {
      throw new IllegalArgumentException(
          "The condition type " + conditionType.name() + " is not created during initialization.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(
        SOURCE_CONDITION_IMAGE_STREAM_NAME,
        conditionImageGraphsContainerTaskRunner
            .getPacketCreator()
            .createImage(sourceConditionImage));
    inputPackets.put(
        SELECT_STREAM_NAME,
        conditionImageGraphsContainerTaskRunner
            .getPacketCreator()
            .createInt32(conditionTypeIndex.get(conditionType)));
    ConditionImageResult result =
        (ConditionImageResult) conditionImageGraphsContainerTaskRunner.process(inputPackets);
    return result.conditionImage();
  }

  private ImageGeneratorResult runIterations(
      String prompt, int steps, int seed, @Nullable MPImage conditionImage, int select) {
    ImageGeneratorResult result = null;
    long timestamp = System.currentTimeMillis() * MICROSECONDS_PER_MILLISECOND;
    for (int i = 0; i < steps; i++) {
      Map<String, Packet> inputPackets = new HashMap<>();
      if (i == 0 && useConditionImage) {
        inputPackets.put(
            CONDITION_IMAGE_STREAM_NAME, runner.getPacketCreator().createImage(conditionImage));
        inputPackets.put(SELECT_STREAM_NAME, runner.getPacketCreator().createInt32(select));
      }
      inputPackets.put(PROMPT_STREAM_NAME, runner.getPacketCreator().createString(prompt));
      inputPackets.put(STEPS_STREAM_NAME, runner.getPacketCreator().createInt32(steps));
      inputPackets.put(ITERATION_STREAM_NAME, runner.getPacketCreator().createInt32(i));
      inputPackets.put(RAND_SEED_STREAM_NAME, runner.getPacketCreator().createInt32(seed));
      result = (ImageGeneratorResult) runner.process(inputPackets, timestamp++);
    }
    if (useConditionImage) {
      // Add the condition image to the ImageGeneratorResult.
      return ImageGeneratorResult.create(
          result.generatedImage(), conditionImage, result.timestampMs());
    }
    return result;
  }

  /** Closes and cleans up the task runners. */
  @Override
  public void close() {
    runner.close();
    // The condition image task runner only exists when condition options were provided.
    if (conditionImageGraphsContainerTaskRunner != null) {
      conditionImageGraphsContainerTaskRunner.close();
    }
  }

  /** A container class for the condition image. */
  @AutoValue
  protected abstract static class ConditionImageResult implements TaskResult {

    public abstract MPImage conditionImage();

    @Override
    public abstract long timestampMs();
  }

  /** Options for setting up an {@link ImageGenerator}. */
  @AutoValue
  public abstract static class ImageGeneratorOptions extends TaskOptions {

    /** Builder for {@link ImageGeneratorOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {

      /** Sets the text-to-image model directory storing the model weights. */
      public abstract Builder setText2ImageModelDirectory(String modelDirectory);

      /** Sets the path to the LoRA weights file. */
      public abstract Builder setLoraWeightsFilePath(String loraWeightsFilePath);

      public abstract Builder setResultListener(
          PureResultListener<ImageGeneratorResult> resultListener);

      /** Sets an optional {@link ErrorListener}. */
      public abstract Builder setErrorListener(ErrorListener value);

      abstract ImageGeneratorOptions autoBuild();

      /** Validates and builds the {@link ImageGeneratorOptions} instance. */
      public final ImageGeneratorOptions build() {
        return autoBuild();
      }
    }

    abstract String text2ImageModelDirectory();

    abstract Optional<String> loraWeightsFilePath();

    abstract Optional<PureResultListener<ImageGeneratorResult>> resultListener();

    abstract Optional<ErrorListener> errorListener();

    private Optional<ConditionOptions> conditionOptions;

    public static Builder builder() {
      return new AutoValue_ImageGenerator_ImageGeneratorOptions.Builder()
          .setText2ImageModelDirectory("");
    }

    /** Converts an {@link ImageGeneratorOptions} to an {@link Any} protobuf message. */
    @Override
    public Any convertToAnyProto() {
      ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.Builder taskOptionsBuilder =
          ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.newBuilder();
      if (conditionOptions != null && conditionOptions.isPresent()) {
        try {
          taskOptionsBuilder.mergeFrom(
              conditionOptions.get().convertToAnyProto().getValue(),
              ExtensionRegistryLite.getGeneratedRegistry());
        } catch (InvalidProtocolBufferException e) {
          Log.e(TAG, "Error converting ConditionOptions to proto. " + e.getMessage());
          e.printStackTrace();
        }
      }
      taskOptionsBuilder.setText2ImageModelDirectory(text2ImageModelDirectory());
      if (loraWeightsFilePath().isPresent()) {
        ExternalFileProto.ExternalFile.Builder externalFileBuilder =
            ExternalFileProto.ExternalFile.newBuilder();
        externalFileBuilder.setFileName(loraWeightsFilePath().get());
        taskOptionsBuilder.setLoraWeightsFile(externalFileBuilder.build());
      }
      return Any.newBuilder()
          .setTypeUrl(
              "type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
          .setValue(taskOptionsBuilder.build().toByteString())
          .build();
    }
  }

  /** Options for setting up the condition types and the plugin models. */
  @AutoValue
  public abstract static class ConditionOptions extends TaskOptions {

    /** The supported condition type. */
    public enum ConditionType {
      FACE,
      EDGE,
      DEPTH
    }

    /** Builder for {@link ConditionOptions}. At least one type of condition options must be set. */
    @AutoValue.Builder
    public abstract static class Builder {
      public abstract Builder setFaceConditionOptions(FaceConditionOptions faceConditionOptions);

      public abstract Builder setDepthConditionOptions(DepthConditionOptions depthConditionOptions);

      public abstract Builder setEdgeConditionOptions(EdgeConditionOptions edgeConditionOptions);

      abstract ConditionOptions autoBuild();

      /** Validates and builds the {@link ConditionOptions} instance. */
      public final ConditionOptions build() {
        ConditionOptions options = autoBuild();
        if (!options.faceConditionOptions().isPresent()
            && !options.depthConditionOptions().isPresent()
            && !options.edgeConditionOptions().isPresent()) {
          throw new IllegalArgumentException(
              "At least one of `faceConditionOptions`, `depthConditionOptions` and"
                  + " `edgeConditionOptions` must be set.");
        }
        return options;
      }
    }

    abstract Optional<FaceConditionOptions> faceConditionOptions();

    abstract Optional<DepthConditionOptions> depthConditionOptions();

    abstract Optional<EdgeConditionOptions> edgeConditionOptions();

    public static Builder builder() {
      return new AutoValue_ImageGenerator_ConditionOptions.Builder();
    }

    /** Converts a {@link ConditionOptions} to an {@link Any} protobuf message. */
    @Override
    public Any convertToAnyProto() {
      ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.Builder taskOptionsBuilder =
          ImageGeneratorGraphOptionsProto.ImageGeneratorGraphOptions.newBuilder();
      if (faceConditionOptions().isPresent()) {
        taskOptionsBuilder.addControlPluginGraphsOptions(
            ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
                .setBaseOptions(
                    convertBaseOptionsToProto(faceConditionOptions().get().baseOptions()))
                .setConditionedImageGraphOptions(
                    ConditionedImageGraphOptions.newBuilder()
                        .setFaceConditionTypeOptions(faceConditionOptions().get().convertToProto())
                        .build())
                .build());
      }
      if (edgeConditionOptions().isPresent()) {
        taskOptionsBuilder.addControlPluginGraphsOptions(
            ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
                .setBaseOptions(
                    convertBaseOptionsToProto(edgeConditionOptions().get().baseOptions()))
                .setConditionedImageGraphOptions(
                    ConditionedImageGraphOptions.newBuilder()
                        .setEdgeConditionTypeOptions(edgeConditionOptions().get().convertToProto())
                        .build())
                .build());
      }
      if (depthConditionOptions().isPresent()) {
        taskOptionsBuilder.addControlPluginGraphsOptions(
            ControlPluginGraphOptionsProto.ControlPluginGraphOptions.newBuilder()
                .setBaseOptions(
                    convertBaseOptionsToProto(depthConditionOptions().get().baseOptions()))
                .setConditionedImageGraphOptions(
                    ConditionedImageGraphOptions.newBuilder()
                        .setDepthConditionTypeOptions(
                            depthConditionOptions().get().convertToProto())
                        .build())
                .build());
      }
      return Any.newBuilder()
          .setTypeUrl(
              "type.googleapis.com/mediapipe.tasks.vision.image_generator.proto.ImageGeneratorGraphOptions")
          .setValue(taskOptionsBuilder.build().toByteString())
          .build();
    }

    /** Options for drawing the face landmarks image. */
    @AutoValue
    public abstract static class FaceConditionOptions extends TaskOptions {

      /** Builder for {@link FaceConditionOptions}. */
      @AutoValue.Builder
      public abstract static class Builder {
        /** Sets the base options for the plugin model. */
        public abstract Builder setBaseOptions(BaseOptions baseOptions);

        /** Sets the {@link FaceLandmarkerOptions} used to detect face landmarks in the source image. */
        public abstract Builder setFaceLandmarkerOptions(
            FaceLandmarkerOptions faceLandmarkerOptions);

        abstract FaceConditionOptions autoBuild();

        /** Validates and builds the {@link FaceConditionOptions} instance. */
        public final FaceConditionOptions build() {
          return autoBuild();
        }
      }

      abstract BaseOptions baseOptions();

      abstract FaceLandmarkerOptions faceLandmarkerOptions();

      public static Builder builder() {
        return new AutoValue_ImageGenerator_ConditionOptions_FaceConditionOptions.Builder();
      }

      ConditionedImageGraphOptions.FaceConditionTypeOptions convertToProto() {
        return ConditionedImageGraphOptions.FaceConditionTypeOptions.newBuilder()
            .setFaceLandmarkerGraphOptions(
                FaceLandmarkerGraphOptions.newBuilder()
                    .mergeFrom(
                        faceLandmarkerOptions()
                            .convertToCalculatorOptionsProto()
                            .getExtension(FaceLandmarkerGraphOptions.ext))
                    .build())
            .build();
      }
    }

    /** Options for detecting the depth image. */
    @AutoValue
    public abstract static class DepthConditionOptions extends TaskOptions {

      /** Builder for {@link DepthConditionOptions}. */
      @AutoValue.Builder
      public abstract static class Builder {

        /** Sets the base options for the plugin model. */
        public abstract Builder setBaseOptions(BaseOptions baseOptions);

        /** Sets the {@link ImageSegmenterOptions} used to detect the depth image from the source image. */
        public abstract Builder setImageSegmenterOptions(
            ImageSegmenterOptions imageSegmenterOptions);

        abstract DepthConditionOptions autoBuild();

        /** Validates and builds the {@link DepthConditionOptions} instance. */
        public final DepthConditionOptions build() {
          DepthConditionOptions options = autoBuild();
          return options;
        }
      }

      abstract BaseOptions baseOptions();

      abstract ImageSegmenterOptions imageSegmenterOptions();

      public static Builder builder() {
        return new AutoValue_ImageGenerator_ConditionOptions_DepthConditionOptions.Builder();
      }

      ConditionedImageGraphOptions.DepthConditionTypeOptions convertToProto() {
        return ConditionedImageGraphOptions.DepthConditionTypeOptions.newBuilder()
            .setImageSegmenterGraphOptions(
                imageSegmenterOptions()
                    .convertToCalculatorOptionsProto()
                    .getExtension(ImageSegmenterGraphOptions.ext))
            .build();
      }
    }

    /** Options for detecting the edge image. */
    @AutoValue
    public abstract static class EdgeConditionOptions {

      /**
       * Builder for {@link EdgeConditionOptions}.
       *
       * <p>These parameters are used to configure the Canny edge detection algorithm of OpenCV.
       *
       * <p>See more details:
       * https://docs.opencv.org/3.4/dd/d1a/group__imgproc__feature.html#ga04723e007ed888ddf11d9ba04e2232de
       */
      @AutoValue.Builder
      public abstract static class Builder {

        /** Sets the base options for the plugin model. */
        public abstract Builder setBaseOptions(BaseOptions baseOptions);

        /** First threshold for the hysteresis procedure. */
        public abstract Builder setThreshold1(Float threshold1);

        /** Second threshold for the hysteresis procedure. */
        public abstract Builder setThreshold2(Float threshold2);

        /** Aperture size for the Sobel operator. Typical range is 3~7. */
        public abstract Builder setApertureSize(Integer apertureSize);

        /**
         * A flag indicating whether a more accurate L2 norm should be used to calculate the image
         * gradient magnitude (L2gradient=true), or whether the default L1 norm is enough
         * (L2gradient=false).
         */
        public abstract Builder setL2Gradient(Boolean l2Gradient);

        abstract EdgeConditionOptions autoBuild();

        /** Validates and builds the {@link EdgeConditionOptions} instance. */
        public final EdgeConditionOptions build() {
          return autoBuild();
        }
      }

      abstract BaseOptions baseOptions();

      abstract Float threshold1();

      abstract Float threshold2();

      abstract Integer apertureSize();

      abstract Boolean l2Gradient();

      public static Builder builder() {
        return new AutoValue_ImageGenerator_ConditionOptions_EdgeConditionOptions.Builder()
            .setThreshold1(100f)
            .setThreshold2(200f)
            .setApertureSize(3)
            .setL2Gradient(false);
      }

      ConditionedImageGraphOptions.EdgeConditionTypeOptions convertToProto() {
        return ConditionedImageGraphOptions.EdgeConditionTypeOptions.newBuilder()
            .setThreshold1(threshold1())
            .setThreshold2(threshold2())
            .setApertureSize(apertureSize())
            .setL2Gradient(l2Gradient())
            .build();
      }
    }
  }
}
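For reference, here is a minimal usage sketch of how the new Java API above might be driven from Android application code. It uses only the public methods introduced in this change; the model directory path, prompt, iteration count, and seed are placeholder assumptions, not values taken from this commit.

// Illustrative usage sketch only; paths and parameter values are assumptions.
import android.content.Context;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator;
import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ImageGeneratorOptions;
import com.google.mediapipe.tasks.vision.imagegenerator.ImageGeneratorResult;

final class ImageGeneratorUsageSketch {

  static MPImage generateOnce(Context context) {
    ImageGeneratorOptions options =
        ImageGeneratorOptions.builder()
            // Assumed on-device location of the converted text-to-image model weights.
            .setText2ImageModelDirectory("/data/local/tmp/image_generator/")
            .build();
    ImageGenerator generator = ImageGenerator.createFromOptions(context, options);
    // Runs all diffusion iterations synchronously and returns the final image.
    ImageGeneratorResult result =
        generator.generate(
            "a photo of a cat wearing sunglasses", /* iterations= */ 20, /* seed= */ 0);
    generator.close();
    return result.generatedImage();
  }
}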
@@ -0,0 +1,44 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.imagegenerator;

import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.Optional;

/** Represents the image generation results generated by {@link ImageGenerator}. */
@AutoValue
public abstract class ImageGeneratorResult implements TaskResult {

  /** Creates an {@link ImageGeneratorResult} instance from the generated image and the condition image. */
  public static ImageGeneratorResult create(
      MPImage generatedImage, MPImage conditionImage, long timestampMs) {
    return new AutoValue_ImageGeneratorResult(
        generatedImage, Optional.of(conditionImage), timestampMs);
  }

  /** Creates an {@link ImageGeneratorResult} instance from the generated image. */
  public static ImageGeneratorResult create(MPImage generatedImage, long timestampMs) {
    return new AutoValue_ImageGeneratorResult(generatedImage, Optional.empty(), timestampMs);
  }

  public abstract MPImage generatedImage();

  public abstract Optional<MPImage> conditionImage();

  @Override
  public abstract long timestampMs();
}
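A second sketch, continuing from the imports in the previous one, shows the condition-image path using ConditionOptions with edge conditioning. The plugin model asset name and the source image are assumptions for illustration; the thresholds mirror the defaults defined above.

// Illustrative sketch only; the plugin model asset name is an assumption.
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ConditionOptions;
import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ConditionOptions.ConditionType;
import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ConditionOptions.EdgeConditionOptions;

final class ConditionedGenerationSketch {

  static MPImage generateWithEdgeCondition(
      Context context, ImageGeneratorOptions generatorOptions, MPImage sourceImage) {
    ConditionOptions conditionOptions =
        ConditionOptions.builder()
            .setEdgeConditionOptions(
                EdgeConditionOptions.builder()
                    // Assumed ControlNet-style edge plugin model bundled with the app.
                    .setBaseOptions(
                        BaseOptions.builder().setModelAssetPath("edge_plugin.tflite").build())
                    .setThreshold1(100f)
                    .setThreshold2(200f)
                    .build())
            .build();
    ImageGenerator generator =
        ImageGenerator.createFromOptions(context, generatorOptions, conditionOptions);
    // The source image is converted internally into a Canny edge condition image.
    ImageGeneratorResult result =
        generator.generate(
            "a sports car on a rainy street",
            sourceImage,
            ConditionType.EDGE,
            /* iterations= */ 20,
            /* seed= */ 0);
    // The condition image is attached to the result when condition options are used.
    MPImage edgeImage = result.conditionImage().get();
    generator.close();
    return result.generatedImage();
  }
}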