internal update.

PiperOrigin-RevId: 561995330
This commit is contained in:
MediaPipe Team 2023-09-01 10:51:42 -07:00 committed by Copybara-Service
parent 823493ee82
commit ceb8cd3c78

View File

@ -75,6 +75,11 @@ DiffuserModelType ToDiffuserModelType(
// The iteration of the current run.
// PLUGIN_TENSORS - std::vector<mediapipe::Tensor> @Optional
// The output tensor vector of the diffusion plugins model.
// PLUGIN_STRENGTH - float @Optional
// The strength of the plugin tensors.
// SHOW_RESULT - bool @Optional
// Whether to show the diffusion result at the current step, regardless
// of what show_every_n_iteration is set to.
//
// Outputs:
// IMAGE - mediapipe::ImageFrame
@ -104,9 +109,12 @@ class StableDiffusionIterateCalculator : public Node {
kOptionsIn{"OPTIONS"};
static constexpr Input<std::vector<Tensor>>::Optional kPlugInTensorsIn{
"PLUGIN_TENSORS"};
static constexpr Input<float>::Optional kPluginStrengthIn{"PLUGIN_STRENGTH"};
static constexpr Input<bool>::Optional kShowResultIn{"SHOW_RESULT"};
static constexpr Output<mediapipe::ImageFrame> kImageOut{"IMAGE"};
MEDIAPIPE_NODE_CONTRACT(kPromptIn, kStepsIn, kIterationIn, kRandSeedIn,
kPlugInTensorsIn, kOptionsIn, kImageOut);
kPlugInTensorsIn, kPluginStrengthIn, kShowResultIn,
kOptionsIn, kImageOut);
~StableDiffusionIterateCalculator() {
if (context_) DiffuserDelete();
@ -237,11 +245,17 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
const int steps = *kStepsIn(cc);
const int rand_seed = !kRandSeedIn(cc).IsEmpty() ? std::abs(*kRandSeedIn(cc))
: options.base_seed();
float plugins_strength = options.plugins_strength();
if (kPluginStrengthIn(cc).IsConnected() && !kPluginStrengthIn(cc).IsEmpty()) {
plugins_strength = kPluginStrengthIn(cc).Get();
RET_CHECK(plugins_strength >= 0.0f || plugins_strength <= 1.0f)
<< "The value of plugins_strength must be in the range of [0, 1].";
}
if (kIterationIn(cc).IsEmpty()) {
const auto plugin_tensors = GetPluginTensors(cc);
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
options.plugins_strength(), &plugin_tensors));
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, plugins_strength,
&plugin_tensors));
for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
options.output_image_height());
@ -255,14 +269,19 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
if (iteration == 0) {
const auto plugin_tensors = GetPluginTensors(cc);
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
options.plugins_strength(), &plugin_tensors));
plugins_strength, &plugin_tensors));
}
RET_CHECK(DiffuserIterate(steps, iteration));
bool force_show_result = kShowResultIn(cc).IsConnected() &&
!kShowResultIn(cc).IsEmpty() &&
kShowResultIn(cc).Get();
bool show_result = force_show_result ||
(iteration + 1) % show_every_n_iteration_ == 0 ||
iteration == steps - 1;
// Decode the output and send out the image for visualization.
if ((iteration + 1) % show_every_n_iteration_ == 0 ||
iteration == steps - 1) {
if (show_result) {
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
options.output_image_height());
RET_CHECK(DiffuserDecode(image_out.MutablePixelData()));