internal update.
PiperOrigin-RevId: 561995330
This commit is contained in:
parent
823493ee82
commit
ceb8cd3c78
|
@ -75,6 +75,11 @@ DiffuserModelType ToDiffuserModelType(
|
||||||
// The iteration of the current run.
|
// The iteration of the current run.
|
||||||
// PLUGIN_TENSORS - std::vector<mediapipe::Tensor> @Optional
|
// PLUGIN_TENSORS - std::vector<mediapipe::Tensor> @Optional
|
||||||
// The output tensor vector of the diffusion plugins model.
|
// The output tensor vector of the diffusion plugins model.
|
||||||
|
// PLUGIN_STRENGTH - float @Optional
|
||||||
|
// The strength of the plugin tensors.
|
||||||
|
// SHOW_RESULT - bool @Optional
|
||||||
|
// Whether to show the diffusion result at the current step, regardless
|
||||||
|
// of what show_every_n_iteration is set to.
|
||||||
//
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// IMAGE - mediapipe::ImageFrame
|
// IMAGE - mediapipe::ImageFrame
|
||||||
|
@ -104,9 +109,12 @@ class StableDiffusionIterateCalculator : public Node {
|
||||||
kOptionsIn{"OPTIONS"};
|
kOptionsIn{"OPTIONS"};
|
||||||
static constexpr Input<std::vector<Tensor>>::Optional kPlugInTensorsIn{
|
static constexpr Input<std::vector<Tensor>>::Optional kPlugInTensorsIn{
|
||||||
"PLUGIN_TENSORS"};
|
"PLUGIN_TENSORS"};
|
||||||
|
static constexpr Input<float>::Optional kPluginStrengthIn{"PLUGIN_STRENGTH"};
|
||||||
|
static constexpr Input<bool>::Optional kShowResultIn{"SHOW_RESULT"};
|
||||||
static constexpr Output<mediapipe::ImageFrame> kImageOut{"IMAGE"};
|
static constexpr Output<mediapipe::ImageFrame> kImageOut{"IMAGE"};
|
||||||
MEDIAPIPE_NODE_CONTRACT(kPromptIn, kStepsIn, kIterationIn, kRandSeedIn,
|
MEDIAPIPE_NODE_CONTRACT(kPromptIn, kStepsIn, kIterationIn, kRandSeedIn,
|
||||||
kPlugInTensorsIn, kOptionsIn, kImageOut);
|
kPlugInTensorsIn, kPluginStrengthIn, kShowResultIn,
|
||||||
|
kOptionsIn, kImageOut);
|
||||||
|
|
||||||
~StableDiffusionIterateCalculator() {
|
~StableDiffusionIterateCalculator() {
|
||||||
if (context_) DiffuserDelete();
|
if (context_) DiffuserDelete();
|
||||||
|
@ -237,11 +245,17 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
|
||||||
const int steps = *kStepsIn(cc);
|
const int steps = *kStepsIn(cc);
|
||||||
const int rand_seed = !kRandSeedIn(cc).IsEmpty() ? std::abs(*kRandSeedIn(cc))
|
const int rand_seed = !kRandSeedIn(cc).IsEmpty() ? std::abs(*kRandSeedIn(cc))
|
||||||
: options.base_seed();
|
: options.base_seed();
|
||||||
|
float plugins_strength = options.plugins_strength();
|
||||||
|
if (kPluginStrengthIn(cc).IsConnected() && !kPluginStrengthIn(cc).IsEmpty()) {
|
||||||
|
plugins_strength = kPluginStrengthIn(cc).Get();
|
||||||
|
RET_CHECK(plugins_strength >= 0.0f || plugins_strength <= 1.0f)
|
||||||
|
<< "The value of plugins_strength must be in the range of [0, 1].";
|
||||||
|
}
|
||||||
|
|
||||||
if (kIterationIn(cc).IsEmpty()) {
|
if (kIterationIn(cc).IsEmpty()) {
|
||||||
const auto plugin_tensors = GetPluginTensors(cc);
|
const auto plugin_tensors = GetPluginTensors(cc);
|
||||||
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
|
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, plugins_strength,
|
||||||
options.plugins_strength(), &plugin_tensors));
|
&plugin_tensors));
|
||||||
for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
|
for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
|
||||||
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
|
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
|
||||||
options.output_image_height());
|
options.output_image_height());
|
||||||
|
@ -255,14 +269,19 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
|
||||||
if (iteration == 0) {
|
if (iteration == 0) {
|
||||||
const auto plugin_tensors = GetPluginTensors(cc);
|
const auto plugin_tensors = GetPluginTensors(cc);
|
||||||
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
|
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
|
||||||
options.plugins_strength(), &plugin_tensors));
|
plugins_strength, &plugin_tensors));
|
||||||
}
|
}
|
||||||
|
|
||||||
RET_CHECK(DiffuserIterate(steps, iteration));
|
RET_CHECK(DiffuserIterate(steps, iteration));
|
||||||
|
|
||||||
|
bool force_show_result = kShowResultIn(cc).IsConnected() &&
|
||||||
|
!kShowResultIn(cc).IsEmpty() &&
|
||||||
|
kShowResultIn(cc).Get();
|
||||||
|
bool show_result = force_show_result ||
|
||||||
|
(iteration + 1) % show_every_n_iteration_ == 0 ||
|
||||||
|
iteration == steps - 1;
|
||||||
// Decode the output and send out the image for visualization.
|
// Decode the output and send out the image for visualization.
|
||||||
if ((iteration + 1) % show_every_n_iteration_ == 0 ||
|
if (show_result) {
|
||||||
iteration == steps - 1) {
|
|
||||||
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
|
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
|
||||||
options.output_image_height());
|
options.output_image_height());
|
||||||
RET_CHECK(DiffuserDecode(image_out.MutablePixelData()));
|
RET_CHECK(DiffuserDecode(image_out.MutablePixelData()));
|
||||||
|
|
Loading…
Reference in New Issue
Block a user