Internal changes
PiperOrigin-RevId: 561398473
This commit is contained in:
parent
5434b840f6
commit
f60da2120d
|
@ -61,7 +61,6 @@ typedef struct {
|
||||||
int image_width;
|
int image_width;
|
||||||
int image_height;
|
int image_height;
|
||||||
int run_unet_with_plugins;
|
int run_unet_with_plugins;
|
||||||
float plugins_strength;
|
|
||||||
DiffuserEnvironmentOptions env_options;
|
DiffuserEnvironmentOptions env_options;
|
||||||
} DiffuserConfig;
|
} DiffuserConfig;
|
||||||
|
|
||||||
|
@ -76,7 +75,7 @@ typedef struct {
|
||||||
|
|
||||||
DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*); // NOLINT
|
DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*); // NOLINT
|
||||||
DG_EXPORT int DiffuserReset(DiffuserContext*, // NOLINT
|
DG_EXPORT int DiffuserReset(DiffuserContext*, // NOLINT
|
||||||
const char*, int, int, const void*);
|
const char*, int, int, float, const void*);
|
||||||
DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int); // NOLINT
|
DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int); // NOLINT
|
||||||
DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*); // NOLINT
|
DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*); // NOLINT
|
||||||
DG_EXPORT void DiffuserDelete(DiffuserContext*); // NOLINT
|
DG_EXPORT void DiffuserDelete(DiffuserContext*); // NOLINT
|
||||||
|
|
|
@ -141,7 +141,7 @@ class StableDiffusionIterateCalculator : public Node {
|
||||||
dlsym(handle_, "DiffuserCreate"));
|
dlsym(handle_, "DiffuserCreate"));
|
||||||
RET_CHECK(create_ptr_) << dlerror();
|
RET_CHECK(create_ptr_) << dlerror();
|
||||||
reset_ptr_ =
|
reset_ptr_ =
|
||||||
reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int,
|
reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int, float,
|
||||||
const void*)>(dlsym(handle_, "DiffuserReset"));
|
const void*)>(dlsym(handle_, "DiffuserReset"));
|
||||||
RET_CHECK(reset_ptr_) << dlerror();
|
RET_CHECK(reset_ptr_) << dlerror();
|
||||||
iterate_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, int, int)>(
|
iterate_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, int, int)>(
|
||||||
|
@ -159,9 +159,9 @@ class StableDiffusionIterateCalculator : public Node {
|
||||||
DiffuserContext* DiffuserCreate(const DiffuserConfig* a) {
|
DiffuserContext* DiffuserCreate(const DiffuserConfig* a) {
|
||||||
return (*create_ptr_)(a);
|
return (*create_ptr_)(a);
|
||||||
}
|
}
|
||||||
bool DiffuserReset(const char* a, int b, int c,
|
bool DiffuserReset(const char* a, int b, int c, float d,
|
||||||
const std::vector<DiffuserPluginTensor>* d) {
|
const std::vector<DiffuserPluginTensor>* e) {
|
||||||
return (*reset_ptr_)(context_, a, b, c, d);
|
return (*reset_ptr_)(context_, a, b, c, d, e);
|
||||||
}
|
}
|
||||||
bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); }
|
bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); }
|
||||||
bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); }
|
bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); }
|
||||||
|
@ -170,7 +170,8 @@ class StableDiffusionIterateCalculator : public Node {
|
||||||
void* handle_ = nullptr;
|
void* handle_ = nullptr;
|
||||||
DiffuserContext* context_ = nullptr;
|
DiffuserContext* context_ = nullptr;
|
||||||
DiffuserContext* (*create_ptr_)(const DiffuserConfig*);
|
DiffuserContext* (*create_ptr_)(const DiffuserConfig*);
|
||||||
int (*reset_ptr_)(DiffuserContext*, const char*, int, int, const void*);
|
int (*reset_ptr_)(DiffuserContext*, const char*, int, int, float,
|
||||||
|
const void*);
|
||||||
int (*iterate_ptr_)(DiffuserContext*, int, int);
|
int (*iterate_ptr_)(DiffuserContext*, int, int);
|
||||||
int (*decode_ptr_)(DiffuserContext*, uint8_t*);
|
int (*decode_ptr_)(DiffuserContext*, uint8_t*);
|
||||||
void (*delete_ptr_)(DiffuserContext*);
|
void (*delete_ptr_)(DiffuserContext*);
|
||||||
|
@ -221,8 +222,8 @@ absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) {
|
||||||
.priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()),
|
.priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()),
|
||||||
.performance_hint = kDiffuserPerformanceHintHigh,
|
.performance_hint = kDiffuserPerformanceHintHigh,
|
||||||
};
|
};
|
||||||
config.plugins_strength = options.plugins_strength();
|
RET_CHECK(options.plugins_strength() >= 0.0f ||
|
||||||
RET_CHECK(config.plugins_strength > 0.0f || config.plugins_strength < 1.0f)
|
options.plugins_strength() <= 1.0f)
|
||||||
<< "The value of plugins_strength must be in the range of [0, 1].";
|
<< "The value of plugins_strength must be in the range of [0, 1].";
|
||||||
context_ = DiffuserCreate(&config);
|
context_ = DiffuserCreate(&config);
|
||||||
RET_CHECK(context_);
|
RET_CHECK(context_);
|
||||||
|
@ -239,7 +240,8 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
|
||||||
|
|
||||||
if (kIterationIn(cc).IsEmpty()) {
|
if (kIterationIn(cc).IsEmpty()) {
|
||||||
const auto plugin_tensors = GetPluginTensors(cc);
|
const auto plugin_tensors = GetPluginTensors(cc);
|
||||||
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors));
|
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
|
||||||
|
options.plugins_strength(), &plugin_tensors));
|
||||||
for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
|
for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
|
||||||
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
|
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
|
||||||
options.output_image_height());
|
options.output_image_height());
|
||||||
|
@ -252,8 +254,8 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
|
||||||
// Extract text embedding on first iteration.
|
// Extract text embedding on first iteration.
|
||||||
if (iteration == 0) {
|
if (iteration == 0) {
|
||||||
const auto plugin_tensors = GetPluginTensors(cc);
|
const auto plugin_tensors = GetPluginTensors(cc);
|
||||||
RET_CHECK(
|
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
|
||||||
DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors));
|
options.plugins_strength(), &plugin_tensors));
|
||||||
}
|
}
|
||||||
|
|
||||||
RET_CHECK(DiffuserIterate(steps, iteration));
|
RET_CHECK(DiffuserIterate(steps, iteration));
|
||||||
|
|
Loading…
Reference in New Issue
Block a user