AbstractPhil committed on
Commit
e1b090f
·
verified ·
1 Parent(s): 0bf3e20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -14,7 +14,7 @@ from configs import T5_SHUNT_REPOS
14
  # ─── Device & Model Setup ─────────────────────────────────────
15
  # Don't initialize CUDA here for ZeroGPU compatibility
16
  device = None # Will be set inside the GPU function
17
- dtype = torch.float32
18
 
19
  # Don't load models here - will load inside GPU function
20
  t5_tok = None
@@ -66,7 +66,7 @@ def plot_heat(mat, title):
66
  return buf
67
 
68
  # ─── SDXL Text Encoding ───────────────────────────────────────
69
- def encode_sdxl_prompt(pipe, prompt, negative_prompt="", device=device):
70
  """Generate proper CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
71
 
72
  # Tokenize for both encoders
@@ -128,6 +128,10 @@ def encode_sdxl_prompt(pipe, prompt, negative_prompt="", device=device):
128
  "neg_pooled": neg_pooled_embeds
129
  }
130
 
 
 
 
 
131
 
132
  # ─── Inference ────────────────────────────────────────────
133
  @spaces.GPU
@@ -186,8 +190,7 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
186
 
187
  # Apply CLIP-L adapter
188
  if adapter_l is not None:
189
- clip_l_in = clip_embeds["clip_l"].to(torch.float32)
190
- anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_l_in)
191
  gate_l_scaled = gate_l * gate_prob
192
  delta_l_final = delta_l * strength * gate_l_scaled
193
  clip_l_mod = clip_embeds["clip_l"] + delta_l_final
@@ -204,10 +207,7 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
204
 
205
  # Apply CLIP-G adapter
206
  if adapter_g is not None:
207
- # Float32 adapter input
208
- clip_g_in = clip_embeds["clip_g"].to(torch.float32)
209
-
210
- anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_g_in)
211
  gate_g_scaled = gate_g * gate_prob
212
  delta_g_final = delta_g * strength * gate_g_scaled
213
  clip_g_mod = clip_embeds["clip_g"] + delta_g_final
@@ -243,9 +243,9 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
243
  return (
244
  image,
245
  plot_heat(delta_l_final.squeeze().cpu().numpy(), "Ξ” CLIP-L"),
246
- plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
247
  plot_heat(delta_g_final.squeeze().cpu().numpy(), "Ξ” CLIP-G"),
248
- plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
249
  f"g_pred_l: {g_pred_l.mean().item():.3f}, Ο„_l: {tau_l.mean().item():.3f}",
250
  f"g_pred_g: {g_pred_g.mean().item():.3f}, Ο„_g: {tau_g.mean().item():.3f}"
251
  )
 
14
  # ─── Device & Model Setup ─────────────────────────────────────
15
  # Don't initialize CUDA here for ZeroGPU compatibility
16
  device = None # Will be set inside the GPU function
17
+ dtype = torch.float16
18
 
19
  # Don't load models here - will load inside GPU function
20
  t5_tok = None
 
66
  return buf
67
 
68
  # ─── SDXL Text Encoding ───────────────────────────────────────
69
+ def encode_sdxl_prompt(prompt, negative_prompt=""):
70
  """Generate proper CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
71
 
72
  # Tokenize for both encoders
 
128
  "neg_pooled": neg_pooled_embeds
129
  }
130
 
131
+ # ─── Inference ────────────────────────────────────────────────
132
+ @torch.no_grad()
133
+ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
134
+ use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
135
 
136
  # ─── Inference ────────────────────────────────────────────
137
  @spaces.GPU
 
190
 
191
  # Apply CLIP-L adapter
192
  if adapter_l is not None:
193
+ anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
 
194
  gate_l_scaled = gate_l * gate_prob
195
  delta_l_final = delta_l * strength * gate_l_scaled
196
  clip_l_mod = clip_embeds["clip_l"] + delta_l_final
 
207
 
208
  # Apply CLIP-G adapter
209
  if adapter_g is not None:
210
+ anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
 
 
 
211
  gate_g_scaled = gate_g * gate_prob
212
  delta_g_final = delta_g * strength * gate_g_scaled
213
  clip_g_mod = clip_embeds["clip_g"] + delta_g_final
 
243
  return (
244
  image,
245
  plot_heat(delta_l_final.squeeze().cpu().numpy(), "Ξ” CLIP-L"),
246
+ plot_heat(gate_l_scaled.squeeze().cpu().numpy().mean(axis=-1), "Gate CLIP-L"),
247
  plot_heat(delta_g_final.squeeze().cpu().numpy(), "Ξ” CLIP-G"),
248
+ plot_heat(gate_g_scaled.squeeze().cpu().numpy().mean(axis=-1), "Gate CLIP-G"),
249
  f"g_pred_l: {g_pred_l.mean().item():.3f}, Ο„_l: {tau_l.mean().item():.3f}",
250
  f"g_pred_g: {g_pred_g.mean().item():.3f}, Ο„_g: {tau_g.mean().item():.3f}"
251
  )