Internal changes
PiperOrigin-RevId: 561398473
This commit is contained in:
parent
5434b840f6
commit
f60da2120d
|
@ -61,7 +61,6 @@ typedef struct {
|
|||
int image_width;
|
||||
int image_height;
|
||||
int run_unet_with_plugins;
|
||||
float plugins_strength;
|
||||
DiffuserEnvironmentOptions env_options;
|
||||
} DiffuserConfig;
|
||||
|
||||
|
@ -76,7 +75,7 @@ typedef struct {
|
|||
|
||||
DG_EXPORT DiffuserContext* DiffuserCreate(const DiffuserConfig*); // NOLINT
|
||||
DG_EXPORT int DiffuserReset(DiffuserContext*, // NOLINT
|
||||
const char*, int, int, const void*);
|
||||
const char*, int, int, float, const void*);
|
||||
DG_EXPORT int DiffuserIterate(DiffuserContext*, int, int); // NOLINT
|
||||
DG_EXPORT int DiffuserDecode(DiffuserContext*, uint8_t*); // NOLINT
|
||||
DG_EXPORT void DiffuserDelete(DiffuserContext*); // NOLINT
|
||||
|
|
|
@ -141,7 +141,7 @@ class StableDiffusionIterateCalculator : public Node {
|
|||
dlsym(handle_, "DiffuserCreate"));
|
||||
RET_CHECK(create_ptr_) << dlerror();
|
||||
reset_ptr_ =
|
||||
reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int,
|
||||
reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int, float,
|
||||
const void*)>(dlsym(handle_, "DiffuserReset"));
|
||||
RET_CHECK(reset_ptr_) << dlerror();
|
||||
iterate_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, int, int)>(
|
||||
|
@ -159,9 +159,9 @@ class StableDiffusionIterateCalculator : public Node {
|
|||
DiffuserContext* DiffuserCreate(const DiffuserConfig* a) {
|
||||
return (*create_ptr_)(a);
|
||||
}
|
||||
bool DiffuserReset(const char* a, int b, int c,
|
||||
const std::vector<DiffuserPluginTensor>* d) {
|
||||
return (*reset_ptr_)(context_, a, b, c, d);
|
||||
bool DiffuserReset(const char* a, int b, int c, float d,
|
||||
const std::vector<DiffuserPluginTensor>* e) {
|
||||
return (*reset_ptr_)(context_, a, b, c, d, e);
|
||||
}
|
||||
bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); }
|
||||
bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); }
|
||||
|
@ -170,7 +170,8 @@ class StableDiffusionIterateCalculator : public Node {
|
|||
void* handle_ = nullptr;
|
||||
DiffuserContext* context_ = nullptr;
|
||||
DiffuserContext* (*create_ptr_)(const DiffuserConfig*);
|
||||
int (*reset_ptr_)(DiffuserContext*, const char*, int, int, const void*);
|
||||
int (*reset_ptr_)(DiffuserContext*, const char*, int, int, float,
|
||||
const void*);
|
||||
int (*iterate_ptr_)(DiffuserContext*, int, int);
|
||||
int (*decode_ptr_)(DiffuserContext*, uint8_t*);
|
||||
void (*delete_ptr_)(DiffuserContext*);
|
||||
|
@ -221,8 +222,8 @@ absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) {
|
|||
.priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()),
|
||||
.performance_hint = kDiffuserPerformanceHintHigh,
|
||||
};
|
||||
config.plugins_strength = options.plugins_strength();
|
||||
RET_CHECK(config.plugins_strength > 0.0f || config.plugins_strength < 1.0f)
|
||||
RET_CHECK(options.plugins_strength() >= 0.0f ||
|
||||
options.plugins_strength() <= 1.0f)
|
||||
<< "The value of plugins_strength must be in the range of [0, 1].";
|
||||
context_ = DiffuserCreate(&config);
|
||||
RET_CHECK(context_);
|
||||
|
@ -239,7 +240,8 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
|
|||
|
||||
if (kIterationIn(cc).IsEmpty()) {
|
||||
const auto plugin_tensors = GetPluginTensors(cc);
|
||||
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors));
|
||||
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
|
||||
options.plugins_strength(), &plugin_tensors));
|
||||
for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
|
||||
ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
|
||||
options.output_image_height());
|
||||
|
@ -252,8 +254,8 @@ absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
|
|||
// Extract text embedding on first iteration.
|
||||
if (iteration == 0) {
|
||||
const auto plugin_tensors = GetPluginTensors(cc);
|
||||
RET_CHECK(
|
||||
DiffuserReset(prompt.c_str(), steps, rand_seed, &plugin_tensors));
|
||||
RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
|
||||
options.plugins_strength(), &plugin_tensors));
|
||||
}
|
||||
|
||||
RET_CHECK(DiffuserIterate(steps, iteration));
|
||||
|
|
Loading…
Reference in New Issue
Block a user