Update app.py
app.py CHANGED
@@ -14,7 +14,7 @@ from configs import T5_SHUNT_REPOS
 # ─── Device & Model Setup ─────────────────────────────────────
 # Don't initialize CUDA here for ZeroGPU compatibility
 device = None  # Will be set inside the GPU function
-dtype = torch.
+dtype = torch.float16
 
 # Don't load models here - will load inside GPU function
 t5_tok = None
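For context, the ZeroGPU pattern these lines rely on looks roughly like the sketch below: nothing touches CUDA at import time, and both the device and the models are resolved inside a @spaces.GPU-decorated call. The loader call and checkpoint name are placeholders, not the Space's actual code.

import spaces
import torch

device = None                 # resolved lazily inside the GPU function
dtype = torch.float16

t5_tok = None                 # models are also loaded lazily

@spaces.GPU
def warm_up():
    global device, t5_tok
    if device is None:
        device = torch.device("cuda")    # safe here: a GPU is attached for this call
    if t5_tok is None:
        from transformers import AutoTokenizer
        t5_tok = AutoTokenizer.from_pretrained("t5-small")   # placeholder checkpoint
    return str(device)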
@@ -66,7 +66,7 @@ def plot_heat(mat, title):
     return buf
 
 # ─── SDXL Text Encoding ───────────────────────────────────────
-def encode_sdxl_prompt(pipe, prompt, negative_prompt="", device=device):
+def encode_sdxl_prompt(prompt, negative_prompt=""):
     """Generate proper CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
 
     # Tokenize for both encoders
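The body of encode_sdxl_prompt is not part of this diff; below is a hedged sketch of how such a helper is typically written against a diffusers StableDiffusionXLPipeline (penultimate hidden states from both CLIP encoders, pooled output from the second). The pipe argument, the exact layer choices, and every dict key other than "clip_l", "clip_g", and "neg_pooled" are assumptions.

import torch

def encode_sdxl_prompt_sketch(pipe, prompt, negative_prompt="", device="cuda"):
    def encode(text):
        # CLIP-L (tokenizer / text_encoder): 768-dim penultimate hidden states
        ids_l = pipe.tokenizer(text, padding="max_length", max_length=77,
                               truncation=True, return_tensors="pt").input_ids.to(device)
        out_l = pipe.text_encoder(ids_l, output_hidden_states=True)
        clip_l = out_l.hidden_states[-2]

        # CLIP-G (tokenizer_2 / text_encoder_2): hidden states plus pooled text embedding
        ids_g = pipe.tokenizer_2(text, padding="max_length", max_length=77,
                                 truncation=True, return_tensors="pt").input_ids.to(device)
        out_g = pipe.text_encoder_2(ids_g, output_hidden_states=True)
        clip_g, pooled = out_g.hidden_states[-2], out_g[0]
        return clip_l, clip_g, pooled

    clip_l, clip_g, pooled = encode(prompt)
    neg_l, neg_g, neg_pooled = encode(negative_prompt)
    return {"clip_l": clip_l, "clip_g": clip_g, "pooled": pooled,
            "neg_l": neg_l, "neg_g": neg_g, "neg_pooled": neg_pooled}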
@@ -128,6 +128,10 @@ def encode_sdxl_prompt(pipe, prompt, negative_prompt="", device=device):
         "neg_pooled": neg_pooled_embeds
     }
 
+# ─── Inference ─────────────────────────────────────────────────
+@torch.no_grad()
+def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
+          use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
 
 # ─── Inference ────────────────────────────────────────────────
 @spaces.GPU
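For reference, a minimal sketch of how @spaces.GPU and @torch.no_grad() are usually stacked on a single entry point in a ZeroGPU Space; the function name and body are illustrative only, not what this hunk applies.

import spaces
import torch

@spaces.GPU
@torch.no_grad()
def infer_sketch(prompt: str):
    ...  # model calls run without autograd tracking while the GPU is attached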
@@ -186,8 +190,7 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
 
     # Apply CLIP-L adapter
     if adapter_l is not None:
-
-        anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_l_in)
+        anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
         gate_l_scaled = gate_l * gate_prob
         delta_l_final = delta_l * strength * gate_l_scaled
         clip_l_mod = clip_embeds["clip_l"] + delta_l_final
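The gated residual update applied here can be summarized shape-wise as below; the tensor sizes (77 tokens, 768-dim CLIP-L states, a gate with the same trailing dimension) are assumptions for illustration.

import torch

B, T, D = 1, 77, 768                              # batch, tokens, CLIP-L width (assumed)
clip_l  = torch.randn(B, T, D)                    # stand-in for clip_embeds["clip_l"]
delta_l = torch.randn(B, T, D)                    # adapter-predicted correction
gate_l  = torch.sigmoid(torch.randn(B, T, D))     # gate values in (0, 1)

strength, gate_prob = 1.0, 0.5
gate_l_scaled = gate_l * gate_prob                 # user-controlled gate probability
delta_l_final = delta_l * strength * gate_l_scaled # elementwise scaling of the delta
clip_l_mod = clip_l + delta_l_final                # residual update, same shape as clip_l
assert clip_l_mod.shape == (B, T, D)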
@@ -204,10 +207,7 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
 
     # Apply CLIP-G adapter
     if adapter_g is not None:
-
-        clip_g_in = clip_embeds["clip_g"].to(torch.float32)
-
-        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_g_in)
+        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
         gate_g_scaled = gate_g * gate_prob
         delta_g_final = delta_g * strength * gate_g_scaled
         clip_g_mod = clip_embeds["clip_g"] + delta_g_final
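With the removed .to(torch.float32) cast, the adapter now receives the embeddings in the pipeline dtype (presumably float16, given the dtype set in the first hunk). A hypothetical guard, not part of this diff, that aligns inputs with the adapter's own parameter dtype instead of hard-coding a cast could look like:

import torch

def run_adapter(adapter: torch.nn.Module, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
    # Match the adapter's parameter dtype rather than forcing float32.
    param_dtype = next(adapter.parameters()).dtype
    return adapter(t5_seq.to(param_dtype), clip_seq.to(param_dtype))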
@@ -243,9 +243,9 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     return (
         image,
         plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
-        plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
+        plot_heat(gate_l_scaled.squeeze().cpu().numpy().mean(axis=-1), "Gate CLIP-L"),
         plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
-        plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
+        plot_heat(gate_g_scaled.squeeze().cpu().numpy().mean(axis=-1), "Gate CLIP-G"),
         f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
         f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
     )
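plot_heat itself is outside this diff (only its return buf tail and the plot_heat(mat, title) signature appear in the hunk headers). The changed lines average the squeezed gate over its last axis before plotting, presumably to collapse a (tokens, channels) gate into a per-token profile. A sketch of a buffer-returning heatmap helper consistent with that usage, written as an assumption rather than the Space's actual implementation:

import io
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def plot_heat(mat, title):
    mat = np.atleast_2d(np.asarray(mat))      # accept 1-D per-token profiles too
    fig, ax = plt.subplots(figsize=(6, 3))
    im = ax.imshow(mat, aspect="auto", cmap="viridis")
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return buf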