Internal changes

PiperOrigin-RevId: 561398473
This commit is contained in:
Jiuqiang Tang 2023-08-30 11:24:25 -07:00 committed by Copybara-Service
parent 5434b840f6
commit f60da2120d
2 changed files with 13 additions and 12 deletions

View File

@ -61,7 +61,6 @@ typedef struct {
int image_width;
int image_height;
int run_unet_with_plugins;
float plugins_strength;
DiffuserEnvironmentOptions env_options;
} DiffuserConfig;
@ -76,7 +75,7 @@ typedef struct {
DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*); // NOLINT
DG_EXPORT int DiffuserReset(DiffuserContext*, // NOLINT
const char*, int, int, const void*);
const char*, int, int, float, const void*);
DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int); // NOLINT
DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*); // NOLINT
DG_EXPORT void DiffuserDelete(DiffuserContext*); // NOLINT

View File

@ -141,7 +141,7 @@ class StableDiffusionIterateCalculator : public Node {
dlsym(handle_, "DiffuserCreate"));
RET_CHECK(create_ptr_) << dlerror();
reset_ptr_ =
reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int,
reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int, float,
const void*)>(dlsym(handle_, "DiffuserReset"));
RET_CHECK(reset_ptr_) << dlerror();
iterate_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, int, int)>(
@ -159,9 +159,9 @@ class StableDiffusionIterateCalculator : public Node {
DiffuserContext* DiffuserCreate(const DiffuserConfig* a) {
return (*create_ptr_)(a);
}
bool DiffuserReset(const char* a, int b, int c,
const std::vector<DiffuserPluginTensor>* d) {
return (*reset_ptr_)(context_, a, b, c, d);
bool DiffuserReset(const char* a, int b, int c, float d,
const std::vector<DiffuserPluginTensor>* e) {
return (*reset_ptr_)(context_, a, b, c, d, e);
}
bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); }
bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); }
@ -170,7 +170,8 @@ class StableDiffusionIterateCalculator : public Node {
void* handle_ = nullptr;
DiffuserContext* context_ = nullptr;
DiffuserContext* (*create_ptr_)(const DiffuserConfig*);
int (*reset_ptr_)(DiffuserContext*, const char*, int, int, const void*);
int (*reset_ptr_)(DiffuserContext*, const char*, int, int, float,
const void*);
int (*iterate_ptr_)(DiffuserContext*, int, int);
int (*decode_ptr_)(DiffuserContext*, uint8_t*);
void (*delete_ptr_)(DiffuserContext*);
@ -221,8 +222,8 @@ absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) {
.priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()),
.performance_hint = kDiffuserPerformanceHintHigh,
};
config.plugins_strength = options.plugins_strength();
RET_CHECK(config.plugins_strength > 0.0f || config.plugins_strength < 1.0f)
RET_CHECK(options.plugins_strength() >= 0.0f ||
options.plugins_strength() <= 1.0f)
<< "The value of plugins_strength must be in the range of [0, 1].";
context_ = DiffuserCreate(&config);
RET_CHECK(context_);
@ -239,7 +240,8 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
if (kIterationIn(cc).IsEmpty()) {
const auto plugin_tensors = GetPluginTensors(cc);
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors));
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
options.plugins_strength(), &plugin_tensors));
for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
options.output_image_height());
@ -252,8 +254,8 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
// Extract text embedding on first iteration.
if (iteration == 0) {
const auto plugin_tensors = GetPluginTensors(cc);
RET_CHECK(
DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors));
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
options.plugins_strength(), &plugin_tensors));
}
RET_CHECK(DiffuserIterate(steps, iteration));