AbstractPhil commited on
Commit
75ce7bd
·
1 Parent(s): 7229198
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -89,12 +89,20 @@ def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, us
89
  if noise > 0:
90
  clip_g_mod += torch.randn_like(clip_g_mod) * noise
91
 
 
92
  prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
93
  neg_embeds = torch.zeros_like(prompt_embeds)
94
 
 
 
 
 
 
95
  image = pipe(
96
  prompt_embeds=prompt_embeds,
 
97
  negative_prompt_embeds=neg_embeds,
 
98
  num_inference_steps=20,
99
  guidance_scale=5.0
100
  ).images[0]
 
89
  if noise > 0:
90
  clip_g_mod += torch.randn_like(clip_g_mod) * noise
91
 
92
+ # Combine embeddings
93
  prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
94
  neg_embeds = torch.zeros_like(prompt_embeds)
95
 
96
+ # Compute pooled embeds (mean pooling as default fallback)
97
+ pooled_prompt_embeds = prompt_embeds.mean(dim=1)
98
+ pooled_neg_embeds = neg_embeds.mean(dim=1)
99
+
100
+ # SDXL generation with required pooled embeddings
101
  image = pipe(
102
  prompt_embeds=prompt_embeds,
103
+ pooled_prompt_embeds=pooled_prompt_embeds,
104
  negative_prompt_embeds=neg_embeds,
105
+ negative_pooled_prompt_embeds=pooled_neg_embeds,
106
  num_inference_steps=20,
107
  guidance_scale=5.0
108
  ).images[0]