internal update.
PiperOrigin-RevId: 561995330
This commit is contained in:
parent
823493ee82
commit
ceb8cd3c78
|
@ -75,6 +75,11 @@ DiffuserModelType ToDiffuserModelType(
|
|||
// The iteration of the current run.
|
||||
// PLUGIN_TENSORS - std::vector<mediapipe::Tensor> @Optional
|
||||
// The output tensor vector of the diffusion plugins model.
|
||||
// PLUGIN_STRENGTH - float @Optional
|
||||
// The strength of the plugin tensors.
|
||||
// SHOW_RESULT - bool @Optional
|
||||
// Whether to show the diffusion result at the current step, regardless
|
||||
// of what show_every_n_iteration is set to.
|
||||
//
|
||||
// Outputs:
|
||||
// IMAGE - mediapipe::ImageFrame
|
||||
|
@ -104,9 +109,12 @@ class StableDiffusionIterateCalculator : public Node {
|
|||
kOptionsIn{"OPTIONS"};
|
||||
static constexpr Input<std::vector<Tensor>>::Optional kPlugInTensorsIn{
|
||||
"PLUGIN_TENSORS"};
|
||||
static constexpr Input<float>::Optional kPluginStrengthIn{"PLUGIN_STRENGTH"};
|
||||
static constexpr Input<bool>::Optional kShowResultIn{"SHOW_RESULT"};
|
||||
static constexpr Output<mediapipe::ImageFrame> kImageOut{"IMAGE"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kPromptIn, kStepsIn, kIterationIn, kRandSeedIn,
|
||||
kPlugInTensorsIn, kOptionsIn, kImageOut);
|
||||
kPlugInTensorsIn, kPluginStrengthIn, kShowResultIn,
|
||||
kOptionsIn, kImageOut);
|
||||
|
||||
~StableDiffusionIterateCalculator() {
|
||||
if (context_) DiffuserDelete();
|
||||
|
@ -237,11 +245,17 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
|
|||
const int steps = *kStepsIn(cc);
|
||||
const int rand_seed = !kRandSeedIn(cc).IsEmpty() ? std::abs(*kRandSeedIn(cc))
|
||||
: options.base_seed();
|
||||
float plugins_strength = options.plugins_strength();
|
||||
if (kPluginStrengthIn(cc).IsConnected() && !kPluginStrengthIn(cc).IsEmpty()) {
|
||||
plugins_strength = kPluginStrengthIn(cc).Get();
|
||||
RET_CHECK(plugins_strength >= 0.0f || plugins_strength <= 1.0f)
|
||||
<< "The value of plugins_strength must be in the range of [0, 1].";
|
||||
}
|
||||
|
||||
if (kIterationIn(cc).IsEmpty()) {
|
||||
const auto plugin_tensors = GetPluginTensors(cc);
|
||||
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
|
||||
options.plugins_strength(), &plugin_tensors));
|
||||
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, plugins_strength,
|
||||
&plugin_tensors));
|
||||
for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
|
||||
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
|
||||
options.output_image_height());
|
||||
|
@ -255,14 +269,19 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
|
|||
if (iteration == 0) {
|
||||
const auto plugin_tensors = GetPluginTensors(cc);
|
||||
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
|
||||
options.plugins_strength(), &plugin_tensors));
|
||||
plugins_strength, &plugin_tensors));
|
||||
}
|
||||
|
||||
RET_CHECK(DiffuserIterate(steps, iteration));
|
||||
|
||||
bool force_show_result = kShowResultIn(cc).IsConnected() &&
|
||||
!kShowResultIn(cc).IsEmpty() &&
|
||||
kShowResultIn(cc).Get();
|
||||
bool show_result = force_show_result ||
|
||||
(iteration + 1) % show_every_n_iteration_ == 0 ||
|
||||
iteration == steps - 1;
|
||||
// Decode the output and send out the image for visualization.
|
||||
if ((iteration + 1) % show_every_n_iteration_ == 0 ||
|
||||
iteration == steps - 1) {
|
||||
if (show_result) {
|
||||
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
|
||||
options.output_image_height());
|
||||
RET_CHECK(DiffuserDecode(image_out.MutablePixelData()));
|
||||
|
|
Loading…
Reference in New Issue
Block a user