diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h b/mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h index 522f0430c..85738b80b 100644 --- a/mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h @@ -61,7 +61,6 @@ typedef struct { int image_width; int image_height; int run_unet_with_plugins; - float plugins_strength; DiffuserEnvironmentOptions env_options; } DiffuserConfig; @@ -76,7 +75,7 @@ typedef struct { DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*); // NOLINT DG_EXPORT int DiffuserReset(DiffuserContext*, // NOLINT - const char*, int, int, const void*); + const char*, int, int, float, const void*); DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int); // NOLINT DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*); // NOLINT DG_EXPORT void DiffuserDelete(DiffuserContext*); // NOLINT diff --git a/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc b/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc index 77b24a715..2df731611 100644 --- a/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.cc @@ -141,7 +141,7 @@ class StableDiffusionIterateCalculator : public Node { dlsym(handle_, "DiffuserCreate")); RET_CHECK(create_ptr_) << dlerror(); reset_ptr_ = - reinterpret_cast(dlsym(handle_, "DiffuserReset")); RET_CHECK(reset_ptr_) << dlerror(); iterate_ptr_ = reinterpret_cast( @@ -159,9 +159,9 @@ class StableDiffusionIterateCalculator : public Node { DiffuserContext* DiffuserCreate(const DiffuserConfig* a) { return (*create_ptr_)(a); } - bool DiffuserReset(const char* a, int b, int c, - const std::vector* d) { - return (*reset_ptr_)(context_, a, b, c, d); + bool DiffuserReset(const char* a, int b, int c, float d, + const std::vector* e) { + return (*reset_ptr_)(context_, a, b, c, d, e); } bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); } bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); } @@ -170,7 +170,8 @@ class StableDiffusionIterateCalculator : public Node { void* handle_ = nullptr; DiffuserContext* context_ = nullptr; DiffuserContext* (*create_ptr_)(const DiffuserConfig*); - int (*reset_ptr_)(DiffuserContext*, const char*, int, int, const void*); + int (*reset_ptr_)(DiffuserContext*, const char*, int, int, float, + const void*); int (*iterate_ptr_)(DiffuserContext*, int, int); int (*decode_ptr_)(DiffuserContext*, uint8_t*); void (*delete_ptr_)(DiffuserContext*); @@ -221,8 +222,8 @@ absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) { .priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()), .performance_hint = kDiffuserPerformanceHintHigh, }; - config.plugins_strength = options.plugins_strength(); - RET_CHECK(config.plugins_strength > 0.0f || config.plugins_strength < 1.0f) + RET_CHECK(options.plugins_strength() >= 0.0f || + options.plugins_strength() <= 1.0f) << "The value of plugins_strength must be in the range of [0, 1]."; context_ = DiffuserCreate(&config); RET_CHECK(context_); @@ -239,7 +240,8 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) { if (kIterationIn(cc).IsEmpty()) { const auto plugin_tensors = GetPluginTensors(cc); - RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors)); + RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, + options.plugins_strength(), &plugin_tensors)); for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i)); ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(), options.output_image_height()); @@ -252,8 +254,8 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) { // Extract text embedding on first iteration. if (iteration == 0) { const auto plugin_tensors = GetPluginTensors(cc); - RET_CHECK( - DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors)); + RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, + options.plugins_strength(), &plugin_tensors)); } RET_CHECK(DiffuserIterate(steps, iteration));