Spaces: Running on Zero
Update app.py
Browse files

app.py CHANGED
@@ -10,6 +10,7 @@ from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download
 from two_stream_shunt_adapter import TwoStreamShuntAdapter
 from configs import T5_SHUNT_REPOS
+import io
 
 # ─── Global Variables ─────────────────────────────────────────
 t5_tok = None
@@ -33,6 +34,7 @@ config_g = T5_SHUNT_REPOS["clip_g"]["config"]
 
 # ─── Helper Functions ─────────────────────────────────────────
 def load_adapter(repo, filename, config, device):
+    """Load adapter from safetensors file"""
    from safetensors.torch import safe_open
    path = hf_hub_download(repo_id=repo, filename=filename)
 
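Note: the body of `load_adapter` is mostly outside this diff, but the imports show the pattern: download from the Hub, then read tensors with `safe_open`. A minimal sketch of that pattern under those assumptions (`load_state_dict` is a hypothetical helper, not the function in app.py):

```python
# Hedged sketch of a hub-download + safetensors read, mirroring the imports above.
from huggingface_hub import hf_hub_download
from safetensors.torch import safe_open

def load_state_dict(repo, filename, device="cpu"):
    path = hf_hub_download(repo_id=repo, filename=filename)
    tensors = {}
    with safe_open(path, framework="pt", device=device) as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)  # lazily reads one tensor at a time
    return tensors
```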
@@ -46,29 +48,42 @@ def load_adapter(repo, filename, config, device):
 
 def plot_heat(mat, title):
     """Create heatmap visualization with proper shape handling"""
-
+    # Handle different input shapes
+    if isinstance(mat, torch.Tensor):
+        mat = mat.detach().cpu().numpy()
 
     # Ensure we have a 2D array for visualization
     if len(mat.shape) == 1:
+        # 1D array - reshape to single row
         mat = mat.reshape(1, -1)
     elif len(mat.shape) == 3:
-
+        # 3D array - average over batch dimension
+        if mat.shape[0] == 1:
+            mat = mat.squeeze(0)
+        else:
+            mat = mat.mean(axis=0)
     elif len(mat.shape) > 3:
+        # Flatten higher dimensions
         mat = mat.reshape(-1, mat.shape[-1])
 
-
-
-
-
-
-    plt.
+    # Create figure with proper DPI
+    plt.figure(figsize=(8, 4), dpi=100)
+    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper", interpolation='nearest')
+    plt.title(title, fontsize=12, fontweight='bold')
+    plt.xlabel("Token Position")
+    plt.ylabel("Feature Dimension")
+    plt.colorbar(shrink=0.8)
+    plt.tight_layout()
 
+    # Convert to PIL Image
     buf = io.BytesIO()
     plt.savefig(buf, format="png", bbox_inches='tight', dpi=100)
     buf.seek(0)
     pil_image = Image.open(buf)
-    plt.close(
-
+    plt.close()
+
+    # Convert to numpy array for Gradio
+    return np.array(pil_image)
 
 def encode_sdxl_prompt(pipe, prompt, negative_prompt, device):
     """Generate CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
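Note: the reshaping added here is what keeps `plt.imshow` from failing on 1D gate vectors or batched 3D deltas. A self-contained sketch of just that normalization (`normalize_for_heatmap` is a hypothetical helper, not part of app.py):

```python
# Standalone sketch of plot_heat's shape handling; always ends up 2D.
import torch

def normalize_for_heatmap(mat):
    # Tensor -> numpy, then force a 2D array, mirroring the logic above.
    if isinstance(mat, torch.Tensor):
        mat = mat.detach().cpu().numpy()
    if mat.ndim == 1:
        mat = mat.reshape(1, -1)              # single row
    elif mat.ndim == 3:
        mat = mat.squeeze(0) if mat.shape[0] == 1 else mat.mean(axis=0)
    elif mat.ndim > 3:
        mat = mat.reshape(-1, mat.shape[-1])  # flatten leading dims
    return mat

for shape in [(77,), (1, 77, 768), (4, 77, 768), (2, 3, 77, 768)]:
    print(shape, "->", normalize_for_heatmap(torch.zeros(shape)).shape)
```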
@@ -92,15 +107,18 @@ def encode_sdxl_prompt(pipe, prompt, negative_prompt, device):
 
     with torch.no_grad():
         # CLIP-L: [0] = sequence, [1] = pooled
-
-
+        clip_l_output = pipe.text_encoder(tokens_l, output_hidden_states=False)
+        clip_l_embeds = clip_l_output[0]
+
+        neg_clip_l_output = pipe.text_encoder(neg_tokens_l, output_hidden_states=False)
+        neg_clip_l_embeds = neg_clip_l_output[0]
 
-        # CLIP-G: [0] = pooled, [1] = sequence
-        clip_g_output = pipe.text_encoder_2(tokens_g)
+        # CLIP-G: [0] = pooled, [1] = sequence
+        clip_g_output = pipe.text_encoder_2(tokens_g, output_hidden_states=False)
         clip_g_embeds = clip_g_output[1]  # sequence embeddings
         pooled_embeds = clip_g_output[0]  # pooled embeddings
 
-        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
+        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g, output_hidden_states=False)
         neg_clip_g_embeds = neg_clip_g_output[1]
         neg_pooled_embeds = neg_clip_g_output[0]
 
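Note: the index flip between the two encoders is easy to trip over. In diffusers' SDXL pipeline, `pipe.text_encoder` is a `CLIPTextModel` (output `[0]` = hidden sequence, `[1]` = pooled), while `pipe.text_encoder_2` is a `CLIPTextModelWithProjection` (output `[0]` = pooled projection, `[1]` = hidden sequence). A quick shape check, assuming the standard SDXL base checkpoint and a CUDA device:

```python
# Shape sanity check for the two SDXL text encoders (downloads the model).
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

ids_l = pipe.tokenizer("a cat", return_tensors="pt", padding="max_length",
                       max_length=77, truncation=True).input_ids.to("cuda")
ids_g = pipe.tokenizer_2("a cat", return_tensors="pt", padding="max_length",
                         max_length=77, truncation=True).input_ids.to("cuda")

with torch.no_grad():
    out_l = pipe.text_encoder(ids_l)     # CLIPTextModel
    out_g = pipe.text_encoder_2(ids_g)   # CLIPTextModelWithProjection

print(out_l[0].shape)  # torch.Size([1, 77, 768])   sequence
print(out_g[0].shape)  # torch.Size([1, 1280])      pooled projection
print(out_g[1].shape)  # torch.Size([1, 77, 1280])  sequence
```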
@@ -139,6 +157,9 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     if seed != -1:
         torch.manual_seed(seed)
         np.random.seed(seed)
+        generator = torch.Generator(device=device).manual_seed(seed)
+    else:
+        generator = None
 
     # Set scheduler
     if scheduler_name in SCHEDULERS:
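Note: diffusers draws the initial latents from the `generator` argument when one is supplied, so threading a seeded per-device `torch.Generator` into the `pipe(...)` call (see `generator=generator` further down) pins the initial noise regardless of any other RNG consumption. A small sketch of the pattern, on CPU so it runs anywhere:

```python
# Deterministic-draw sketch: the same seeded generator reproduces the noise.
import torch

def make_generator(seed, device="cpu"):
    if seed != -1:
        return torch.Generator(device=device).manual_seed(seed)
    return None  # diffusers then picks a random seed

a = torch.randn(3, generator=make_generator(42))
b = torch.randn(3, generator=make_generator(42))
print(torch.equal(a, b))  # True
```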
@@ -148,7 +169,9 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     t5_ids = t5_tok(
         prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True
     ).input_ids.to(device)
-
+
+    with torch.no_grad():
+        t5_seq = t5_mod(t5_ids).last_hidden_state
 
     # Get CLIP embeddings
     clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
@@ -159,41 +182,83 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
 
     # Apply CLIP-L adapter
     if adapter_l is not None:
-
-
-
-
-
-
-
-
-
-
-
+        with torch.no_grad():
+            # Run adapter forward pass
+            adapter_output = adapter_l(t5_seq.float(), clip_embeds["clip_l"].float())
+
+            # Unpack outputs (ensure correct number of outputs)
+            if len(adapter_output) == 8:
+                anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_output
+            else:
+                # Handle different return formats
+                anchor_l = adapter_output[0]
+                delta_l = adapter_output[1]
+                gate_l = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_l)
+                tau_l = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
+                g_pred_l = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)
+
+            # Apply gate scaling
+            gate_l_scaled = torch.sigmoid(gate_l) * gate_prob
+
+            # Compute final delta with strength and gate
+            delta_l_final = delta_l * strength * gate_l_scaled
+
+            # Apply delta to embeddings
+            clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype)
+
+            # Apply anchor mixing if enabled
+            if use_anchor:
+                clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype)
+
+            # Add noise if specified
+            if noise > 0:
+                clip_l_mod += torch.randn_like(clip_l_mod) * noise
     else:
         clip_l_mod = clip_embeds["clip_l"]
         delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
         gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
-        g_pred_l
+        g_pred_l = torch.tensor(0.0)
+        tau_l = torch.tensor(0.0)
 
     # Apply CLIP-G adapter
     if adapter_g is not None:
-
-
-
-
-
-
-
-
-
-
-
+        with torch.no_grad():
+            # Run adapter forward pass
+            adapter_output = adapter_g(t5_seq.float(), clip_embeds["clip_g"].float())
+
+            # Unpack outputs (ensure correct number of outputs)
+            if len(adapter_output) == 8:
+                anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_output
+            else:
+                # Handle different return formats
+                anchor_g = adapter_output[0]
+                delta_g = adapter_output[1]
+                gate_g = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_g)
+                tau_g = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
+                g_pred_g = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)
+
+            # Apply gate scaling
+            gate_g_scaled = torch.sigmoid(gate_g) * gate_prob
+
+            # Compute final delta with strength and gate
+            delta_g_final = delta_g * strength * gate_g_scaled
+
+            # Apply delta to embeddings
+            clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype)
+
+            # Apply anchor mixing if enabled
+            if use_anchor:
+                clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype)
+
+            # Add noise if specified
+            if noise > 0:
+                clip_g_mod += torch.randn_like(clip_g_mod) * noise
     else:
         clip_g_mod = clip_embeds["clip_g"]
         delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
         gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
-        g_pred_g
+        g_pred_g = torch.tensor(0.0)
+        tau_g = torch.tensor(0.0)
 
     # Combine embeddings for SDXL: [CLIP-L(768) + CLIP-G(1280)] = 2048
     prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
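Note: both adapter branches apply the same modulation: the adapter's delta is scaled by `strength` and a sigmoid gate capped at `gate_prob`, added to the CLIP sequence, then optionally blended toward the adapter's anchor by the same gate. A toy numeric sketch with the adapter mocked out (shapes shrunk; names mirror the diff):

```python
# Toy version of the delta/gate/anchor composition above; no real adapter used.
import torch

B, T, D = 1, 4, 8
clip        = torch.randn(B, T, D)
delta       = torch.randn(B, T, D)   # stands in for the adapter's delta
gate_logits = torch.randn(B, T, D)   # stands in for the adapter's gate
anchor      = torch.randn(B, T, D)   # stands in for the adapter's anchor
strength, gate_prob, use_anchor, noise = 1.0, 0.5, True, 0.0

gate_scaled = torch.sigmoid(gate_logits) * gate_prob   # values in (0, gate_prob)
delta_final = delta * strength * gate_scaled
mod = clip + delta_final
if use_anchor:                                         # gate-weighted blend
    mod = mod * (1 - gate_scaled) + anchor * gate_scaled
if noise > 0:
    mod = mod + torch.randn_like(mod) * noise
print(mod.shape)  # torch.Size([1, 4, 8])
```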
@@ -210,18 +275,18 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
         width=width,
         height=height,
         num_images_per_prompt=1,
-        generator=
+        generator=generator
     ).images[0]
 
     # Create visualizations
-    delta_l_viz = plot_heat(delta_l_final.squeeze()
-    gate_l_viz = plot_heat(gate_l_scaled.squeeze().
-    delta_g_viz = plot_heat(delta_g_final.squeeze()
-    gate_g_viz = plot_heat(gate_g_scaled.squeeze().
+    delta_l_viz = plot_heat(delta_l_final.squeeze(), "CLIP-L Delta Values")
+    gate_l_viz = plot_heat(gate_l_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-L Gate Activations")
+    delta_g_viz = plot_heat(delta_g_final.squeeze(), "CLIP-G Delta Values")
+    gate_g_viz = plot_heat(gate_g_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-G Gate Activations")
 
     # Statistics
-    stats_l = f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}"
-    stats_g = f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
+    stats_l = f"g_pred_l: {float(g_pred_l.mean().item() if hasattr(g_pred_l, 'mean') else g_pred_l):.3f}, τ_l: {float(tau_l.mean().item() if hasattr(tau_l, 'mean') else tau_l):.3f}"
+    stats_g = f"g_pred_g: {float(g_pred_g.mean().item() if hasattr(g_pred_g, 'mean') else g_pred_g):.3f}, τ_g: {float(tau_g.mean().item() if hasattr(tau_g, 'mean') else tau_g):.3f}"
 
     return image, delta_l_viz, gate_l_viz, delta_g_viz, gate_g_viz, stats_l, stats_g
 
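Note: the `hasattr(v, 'mean')` guard makes the stats strings safe whether the fallback path left a plain scalar or the adapter returned a full tensor. The same logic, factored into a helper for illustration (`fmt` is hypothetical, not in app.py):

```python
# Defensive stat formatting: works for tensors and plain floats alike.
import torch

def fmt(name, v):
    return f"{name}: {float(v.mean().item() if hasattr(v, 'mean') else v):.3f}"

print(fmt("g_pred_l", torch.tensor([0.2, 0.4])))  # g_pred_l: 0.300
print(fmt("tau_g", 0.75))                         # tau_g: 0.750
```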
@@ -286,7 +351,7 @@ def create_interface():
                 width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                 height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
 
-                seed = gr.Number(value=-1, label="Seed (-1 for random)")
+                seed = gr.Number(value=-1, label="Seed (-1 for random)", precision=0)
 
                 generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg")
 