AbstractPhil committed on
Commit
b6b9cb1
·
verified ·
1 Parent(s): 5759aab

Update app.py

Browse files
Files changed (1)
  1. app.py +113 -48
app.py CHANGED
@@ -10,6 +10,7 @@ from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download
 from two_stream_shunt_adapter import TwoStreamShuntAdapter
 from configs import T5_SHUNT_REPOS
+import io
 
 # ─── Global Variables ─────────────────────────────────────────
 t5_tok = None
@@ -33,6 +34,7 @@ config_g = T5_SHUNT_REPOS["clip_g"]["config"]
 
 # ─── Helper Functions ─────────────────────────────────────────
 def load_adapter(repo, filename, config, device):
+    """Load adapter from safetensors file"""
     from safetensors.torch import safe_open
     path = hf_hub_download(repo_id=repo, filename=filename)
 
@@ -46,29 +48,42 @@ def load_adapter(repo, filename, config, device):
 
 def plot_heat(mat, title):
     """Create heatmap visualization with proper shape handling"""
-    import io
+    # Handle different input shapes
+    if isinstance(mat, torch.Tensor):
+        mat = mat.detach().cpu().numpy()
 
     # Ensure we have a 2D array for visualization
     if len(mat.shape) == 1:
+        # 1D array - reshape to single row
         mat = mat.reshape(1, -1)
     elif len(mat.shape) == 3:
-        mat = mat.mean(axis=0)
+        # 3D array - average over batch dimension
+        if mat.shape[0] == 1:
+            mat = mat.squeeze(0)
+        else:
+            mat = mat.mean(axis=0)
    elif len(mat.shape) > 3:
+        # Flatten higher dimensions
         mat = mat.reshape(-1, mat.shape[-1])
 
-    fig, ax = plt.subplots(figsize=(8, 4), dpi=100)
-    im = ax.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
-    ax.set_title(title, fontsize=12, fontweight='bold')
-    ax.set_xlabel("Token Position")
-    ax.set_ylabel("Feature Dimension")
-    plt.colorbar(im, ax=ax, shrink=0.8)
+    # Create figure with proper DPI
+    plt.figure(figsize=(8, 4), dpi=100)
+    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper", interpolation='nearest')
+    plt.title(title, fontsize=12, fontweight='bold')
+    plt.xlabel("Token Position")
+    plt.ylabel("Feature Dimension")
+    plt.colorbar(shrink=0.8)
+    plt.tight_layout()
 
+    # Convert to PIL Image
     buf = io.BytesIO()
     plt.savefig(buf, format="png", bbox_inches='tight', dpi=100)
     buf.seek(0)
     pil_image = Image.open(buf)
-    plt.close(fig)
-    return pil_image
+    plt.close()
+
+    # Convert to numpy array for Gradio
+    return np.array(pil_image)
 
 def encode_sdxl_prompt(pipe, prompt, negative_prompt, device):
     """Generate CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
@@ -92,15 +107,18 @@ def encode_sdxl_prompt(pipe, prompt, negative_prompt, device):
 
     with torch.no_grad():
         # CLIP-L: [0] = sequence, [1] = pooled
-        clip_l_embeds = pipe.text_encoder(tokens_l)[0]
-        neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]
+        clip_l_output = pipe.text_encoder(tokens_l, output_hidden_states=False)
+        clip_l_embeds = clip_l_output[0]
+
+        neg_clip_l_output = pipe.text_encoder(neg_tokens_l, output_hidden_states=False)
+        neg_clip_l_embeds = neg_clip_l_output[0]
 
-        # CLIP-G: [0] = pooled, [1] = sequence (different from CLIP-L!)
-        clip_g_output = pipe.text_encoder_2(tokens_g)
+        # CLIP-G: [0] = pooled, [1] = sequence
+        clip_g_output = pipe.text_encoder_2(tokens_g, output_hidden_states=False)
         clip_g_embeds = clip_g_output[1]  # sequence embeddings
         pooled_embeds = clip_g_output[0]  # pooled embeddings
 
-        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
+        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g, output_hidden_states=False)
         neg_clip_g_embeds = neg_clip_g_output[1]
         neg_pooled_embeds = neg_clip_g_output[0]
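The changed encoder calls above hinge on output ordering that differs between the two encoders. A short sketch of that ordering, assuming pipe is a loaded StableDiffusionXLPipeline and tokens_l / tokens_g are the token id tensors built earlier in this function:

import torch

# pipe.text_encoder (CLIP-L) is a transformers CLIPTextModel:
#   out[0] = last_hidden_state (sequence), out[1] = pooler_output (pooled)
# pipe.text_encoder_2 (CLIP-G) is a CLIPTextModelWithProjection:
#   out[0] = text_embeds (pooled, projected), out[1] = last_hidden_state (sequence)
with torch.no_grad():
    out_l = pipe.text_encoder(tokens_l)
    seq_l = out_l[0]       # (1, 77, 768)

    out_g = pipe.text_encoder_2(tokens_g)
    pooled_g = out_g[0]    # (1, 1280)
    seq_g = out_g[1]       # (1, 77, 1280)
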
@@ -139,6 +157,9 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     if seed != -1:
         torch.manual_seed(seed)
         np.random.seed(seed)
+        generator = torch.Generator(device=device).manual_seed(seed)
+    else:
+        generator = None
 
     # Set scheduler
     if scheduler_name in SCHEDULERS:
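
A note on what the new generator buys: a seeded torch.Generator makes the sampling path repeatable regardless of global RNG state. A minimal sketch, assuming the module-level device:

import torch

seed = 1234
gen_a = torch.Generator(device=device).manual_seed(seed)
gen_b = torch.Generator(device=device).manual_seed(seed)

# Identical seeds give identical noise draws, hence identical images
assert torch.equal(torch.randn(4, generator=gen_a, device=device),
                   torch.randn(4, generator=gen_b, device=device))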
@@ -148,7 +169,9 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     t5_ids = t5_tok(
         prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True
     ).input_ids.to(device)
-    t5_seq = t5_mod(t5_ids).last_hidden_state
+
+    with torch.no_grad():
+        t5_seq = t5_mod(t5_ids).last_hidden_state
 
     # Get CLIP embeddings
     clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
@@ -159,41 +182,83 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
 
     # Apply CLIP-L adapter
     if adapter_l is not None:
-        anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(
-            t5_seq.float(), clip_embeds["clip_l"].float()
-        )
-        gate_l_scaled = gate_l * gate_prob
-        delta_l_final = delta_l * strength * gate_l_scaled
-        clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype)
-
-        if use_anchor:
-            clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype)
-        if noise > 0:
-            clip_l_mod += torch.randn_like(clip_l_mod) * noise
+        with torch.no_grad():
+            # Run adapter forward pass
+            adapter_output = adapter_l(t5_seq.float(), clip_embeds["clip_l"].float())
+
+            # Unpack outputs (ensure correct number of outputs)
+            if len(adapter_output) == 8:
+                anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_output
+            else:
+                # Handle different return formats
+                anchor_l = adapter_output[0]
+                delta_l = adapter_output[1]
+                gate_l = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_l)
+                tau_l = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
+                g_pred_l = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)
+
+            # Apply gate scaling
+            gate_l_scaled = torch.sigmoid(gate_l) * gate_prob
+
+            # Compute final delta with strength and gate
+            delta_l_final = delta_l * strength * gate_l_scaled
+
+            # Apply delta to embeddings
+            clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype)
+
+            # Apply anchor mixing if enabled
+            if use_anchor:
+                clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype)
+
+            # Add noise if specified
+            if noise > 0:
+                clip_l_mod += torch.randn_like(clip_l_mod) * noise
     else:
         clip_l_mod = clip_embeds["clip_l"]
         delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
         gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
-        g_pred_l, tau_l = torch.tensor(0.0), torch.tensor(0.0)
+        g_pred_l = torch.tensor(0.0)
+        tau_l = torch.tensor(0.0)
 
     # Apply CLIP-G adapter
     if adapter_g is not None:
-        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(
-            t5_seq.float(), clip_embeds["clip_g"].float()
-        )
-        gate_g_scaled = gate_g * gate_prob
-        delta_g_final = delta_g * strength * gate_g_scaled
-        clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype)
-
-        if use_anchor:
-            clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype)
-        if noise > 0:
-            clip_g_mod += torch.randn_like(clip_g_mod) * noise
+        with torch.no_grad():
+            # Run adapter forward pass
+            adapter_output = adapter_g(t5_seq.float(), clip_embeds["clip_g"].float())
+
+            # Unpack outputs (ensure correct number of outputs)
+            if len(adapter_output) == 8:
+                anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_output
+            else:
+                # Handle different return formats
+                anchor_g = adapter_output[0]
+                delta_g = adapter_output[1]
+                gate_g = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_g)
+                tau_g = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
+                g_pred_g = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)
+
+            # Apply gate scaling
+            gate_g_scaled = torch.sigmoid(gate_g) * gate_prob
+
+            # Compute final delta with strength and gate
+            delta_g_final = delta_g * strength * gate_g_scaled
+
+            # Apply delta to embeddings
+            clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype)
+
+            # Apply anchor mixing if enabled
+            if use_anchor:
+                clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype)
+
+            # Add noise if specified
+            if noise > 0:
+                clip_g_mod += torch.randn_like(clip_g_mod) * noise
     else:
         clip_g_mod = clip_embeds["clip_g"]
         delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
         gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
-        g_pred_g, tau_g = torch.tensor(0.0), torch.tensor(0.0)
+        g_pred_g = torch.tensor(0.0)
+        tau_g = torch.tensor(0.0)
 
     # Combine embeddings for SDXL: [CLIP-L(768) + CLIP-G(1280)] = 2048
     prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
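
The adapter update rule above, restated on toy tensors. The sketch mirrors the CLIP-L branch; random values stand in for the trained adapter's outputs:

import torch

clip, delta, anchor = torch.randn(3, 1, 77, 768).unbind(0)
gate = torch.randn(1, 77, 1)                     # raw gate logits
strength, gate_prob, use_anchor = 1.0, 0.5, True

gate_scaled = torch.sigmoid(gate) * gate_prob    # squashed into (0, gate_prob)
delta_final = delta * strength * gate_scaled     # gated residual
clip_mod = clip + delta_final

if use_anchor:
    # Blend toward the anchor in proportion to the gate
    clip_mod = clip_mod * (1 - gate_scaled) + anchor * gate_scaled

With gate_prob at 0.5, the blend keeps more than half of the shifted embedding even at full gate activation; the L and G streams then concatenate along the feature axis (768 + 1280 = 2048), as in the prompt_embeds line above.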
@@ -210,18 +275,18 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
         width=width,
         height=height,
         num_images_per_prompt=1,
-        generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
+        generator=generator
     ).images[0]
 
     # Create visualizations
-    delta_l_viz = plot_heat(delta_l_final.squeeze().cpu().numpy(), "CLIP-L Delta Values")
-    gate_l_viz = plot_heat(gate_l_scaled.squeeze().cpu().numpy().mean(axis=-1, keepdims=True), "CLIP-L Gate Activations")
-    delta_g_viz = plot_heat(delta_g_final.squeeze().cpu().numpy(), "CLIP-G Delta Values")
-    gate_g_viz = plot_heat(gate_g_scaled.squeeze().cpu().numpy().mean(axis=-1, keepdims=True), "CLIP-G Gate Activations")
+    delta_l_viz = plot_heat(delta_l_final.squeeze(), "CLIP-L Delta Values")
+    gate_l_viz = plot_heat(gate_l_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-L Gate Activations")
+    delta_g_viz = plot_heat(delta_g_final.squeeze(), "CLIP-G Delta Values")
+    gate_g_viz = plot_heat(gate_g_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-G Gate Activations")
 
     # Statistics
-    stats_l = f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}"
-    stats_g = f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
+    stats_l = f"g_pred_l: {float(g_pred_l.mean().item() if hasattr(g_pred_l, 'mean') else g_pred_l):.3f}, τ_l: {float(tau_l.mean().item() if hasattr(tau_l, 'mean') else tau_l):.3f}"
+    stats_g = f"g_pred_g: {float(g_pred_g.mean().item() if hasattr(g_pred_g, 'mean') else g_pred_g):.3f}, τ_g: {float(tau_g.mean().item() if hasattr(tau_g, 'mean') else tau_g):.3f}"
 
     return image, delta_l_viz, gate_l_viz, delta_g_viz, gate_g_viz, stats_l, stats_g
 
@@ -286,7 +351,7 @@ def create_interface():
         width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
         height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
 
-        seed = gr.Number(value=-1, label="Seed (-1 for random)")
+        seed = gr.Number(value=-1, label="Seed (-1 for random)", precision=0)
 
         generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg")
 