Internal changes

PiperOrigin-RevId: 561398473
This commit is contained in:
Jiuqiang Tang 2023-08-30 11:24:25 -07:00 committed by Copybara-Service
parent 5434b840f6
commit f60da2120d
2 changed files with 13 additions and 12 deletions

View File

@ -61,7 +61,6 @@ typedef struct {
int image_width; int image_width;
int image_height; int image_height;
int run_unet_with_plugins; int run_unet_with_plugins;
float plugins_strength;
DiffuserEnvironmentOptions env_options; DiffuserEnvironmentOptions env_options;
} DiffuserConfig; } DiffuserConfig;
@ -76,7 +75,7 @@ typedef struct {
DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*); // NOLINT DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*); // NOLINT
DG_EXPORT int DiffuserReset(DiffuserContext*, // NOLINT DG_EXPORT int DiffuserReset(DiffuserContext*, // NOLINT
const char*, int, int, const void*); const char*, int, int, float, const void*);
DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int); // NOLINT DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int); // NOLINT
DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*); // NOLINT DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*); // NOLINT
DG_EXPORT void DiffuserDelete(DiffuserContext*); // NOLINT DG_EXPORT void DiffuserDelete(DiffuserContext*); // NOLINT

View File

@ -141,7 +141,7 @@ class StableDiffusionIterateCalculator : public Node {
dlsym(handle_, "DiffuserCreate")); dlsym(handle_, "DiffuserCreate"));
RET_CHECK(create_ptr_) << dlerror(); RET_CHECK(create_ptr_) << dlerror();
reset_ptr_ = reset_ptr_ =
reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int, reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int, float,
const void*)>(dlsym(handle_, "DiffuserReset")); const void*)>(dlsym(handle_, "DiffuserReset"));
RET_CHECK(reset_ptr_) << dlerror(); RET_CHECK(reset_ptr_) << dlerror();
iterate_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, int, int)>( iterate_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, int, int)>(
@ -159,9 +159,9 @@ class StableDiffusionIterateCalculator : public Node {
DiffuserContext* DiffuserCreate(const DiffuserConfig* a) { DiffuserContext* DiffuserCreate(const DiffuserConfig* a) {
return (*create_ptr_)(a); return (*create_ptr_)(a);
} }
bool DiffuserReset(const char* a, int b, int c, bool DiffuserReset(const char* a, int b, int c, float d,
const std::vector<DiffuserPluginTensor>* d) { const std::vector<DiffuserPluginTensor>* e) {
return (*reset_ptr_)(context_, a, b, c, d); return (*reset_ptr_)(context_, a, b, c, d, e);
} }
bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); } bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); }
bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); } bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); }
@ -170,7 +170,8 @@ class StableDiffusionIterateCalculator : public Node {
void* handle_ = nullptr; void* handle_ = nullptr;
DiffuserContext* context_ = nullptr; DiffuserContext* context_ = nullptr;
DiffuserContext* (*create_ptr_)(const DiffuserConfig*); DiffuserContext* (*create_ptr_)(const DiffuserConfig*);
int (*reset_ptr_)(DiffuserContext*, const char*, int, int, const void*); int (*reset_ptr_)(DiffuserContext*, const char*, int, int, float,
const void*);
int (*iterate_ptr_)(DiffuserContext*, int, int); int (*iterate_ptr_)(DiffuserContext*, int, int);
int (*decode_ptr_)(DiffuserContext*, uint8_t*); int (*decode_ptr_)(DiffuserContext*, uint8_t*);
void (*delete_ptr_)(DiffuserContext*); void (*delete_ptr_)(DiffuserContext*);
@ -221,8 +222,8 @@ absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) {
.priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()), .priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()),
.performance_hint = kDiffuserPerformanceHintHigh, .performance_hint = kDiffuserPerformanceHintHigh,
}; };
config.plugins_strength = options.plugins_strength(); RET_CHECK(options.plugins_strength() >= 0.0f ||
RET_CHECK(config.plugins_strength > 0.0f || config.plugins_strength < 1.0f) options.plugins_strength() <= 1.0f)
<< "The value of plugins_strength must be in the range of [0, 1]."; << "The value of plugins_strength must be in the range of [0, 1].";
context_ = DiffuserCreate(&config); context_ = DiffuserCreate(&config);
RET_CHECK(context_); RET_CHECK(context_);
@ -239,7 +240,8 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
if (kIterationIn(cc).IsEmpty()) { if (kIterationIn(cc).IsEmpty()) {
const auto plugin_tensors = GetPluginTensors(cc); const auto plugin_tensors = GetPluginTensors(cc);
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors)); RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
options.plugins_strength(), &plugin_tensors));
for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i)); for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(), ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
options.output_image_height()); options.output_image_height());
@ -252,8 +254,8 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
// Extract text embedding on first iteration. // Extract text embedding on first iteration.
if (iteration == 0) { if (iteration == 0) {
const auto plugin_tensors = GetPluginTensors(cc); const auto plugin_tensors = GetPluginTensors(cc);
RET_CHECK( RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors)); options.plugins_strength(), &plugin_tensors));
} }
RET_CHECK(DiffuserIterate(steps, iteration)); RET_CHECK(DiffuserIterate(steps, iteration));