AbstractPhil committed on
Commit 5759aab · verified · 1 Parent(s): e1b090f

Update app.py

Files changed (1)
  1. app.py +144 -177
app.py CHANGED
@@ -11,12 +11,7 @@ from huggingface_hub import hf_hub_download
  from two_stream_shunt_adapter import TwoStreamShuntAdapter
  from configs import T5_SHUNT_REPOS
 
- # ─── Device & Model Setup ─────────────────────────────────────
- # Don't initialize CUDA here for ZeroGPU compatibility
- device = None  # Will be set inside the GPU function
- dtype = torch.float16
-
- # Don't load models here - will load inside GPU function
  t5_tok = None
  t5_mod = None
  pipe = None
@@ -36,11 +31,9 @@ repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
  config_l = T5_SHUNT_REPOS["clip_l"]["config"]
  config_g = T5_SHUNT_REPOS["clip_g"]["config"]
 
- # ─── Loader ───────────────────────────────────────────────────
- from safetensors.torch import safe_open
-
- def load_adapter(repo, filename, config):
-     # Don't initialize device here
      path = hf_hub_download(repo_id=repo, filename=filename)
 
      model = TwoStreamShuntAdapter(config).eval()
@@ -49,75 +42,67 @@ def load_adapter(repo, filename, config):
      for key in f.keys():
          tensors[key] = f.get_tensor(key)
      model.load_state_dict(tensors)
-     # Device will be set when called from GPU function
-     return model
 
- # ─── Visualization ────────────────────────────────────────────
  def plot_heat(mat, title):
      import io
-     fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
-     im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
-     ax.set_title(title)
-     plt.colorbar(im, ax=ax)
      buf = io.BytesIO()
-     plt.savefig(buf, format="png", bbox_inches='tight')
      buf.seek(0)
      plt.close(fig)
-     return buf
 
- # ─── SDXL Text Encoding ───────────────────────────────────────
- def encode_sdxl_prompt(prompt, negative_prompt=""):
-     """Generate proper CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
 
      # Tokenize for both encoders
      tokens_l = pipe.tokenizer(
-         prompt,
-         padding="max_length",
-         max_length=77,
-         truncation=True,
-         return_tensors="pt"
      ).input_ids.to(device)
 
      tokens_g = pipe.tokenizer_2(
-         prompt,
-         padding="max_length",
-         max_length=77,
-         truncation=True,
-         return_tensors="pt"
      ).input_ids.to(device)
 
-     # Negative prompts
      neg_tokens_l = pipe.tokenizer(
-         negative_prompt,
-         padding="max_length",
-         max_length=77,
-         truncation=True,
-         return_tensors="pt"
      ).input_ids.to(device)
 
      neg_tokens_g = pipe.tokenizer_2(
-         negative_prompt,
-         padding="max_length",
-         max_length=77,
-         truncation=True,
-         return_tensors="pt"
      ).input_ids.to(device)
 
      with torch.no_grad():
-         # CLIP-L embeddings (768d) - works fine
          clip_l_embeds = pipe.text_encoder(tokens_l)[0]
          neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]
 
-         # CLIP-G embeddings (1280d) - [0] is pooled, [1] is sequence (opposite of CLIP-L)
          clip_g_output = pipe.text_encoder_2(tokens_g)
          clip_g_embeds = clip_g_output[1]  # sequence embeddings
 
          neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
-         neg_clip_g_embeds = neg_clip_g_output[1]  # sequence embeddings
-
-         # Pooled embeddings for SDXL
-         pooled_embeds = clip_g_output[0]  # pooled embeddings
-         neg_pooled_embeds = neg_clip_g_output[0]  # pooled embeddings
 
      return {
          "clip_l": clip_l_embeds,
@@ -128,23 +113,16 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
          "neg_pooled": neg_pooled_embeds
      }
 
- # ─── Inference ────────────────────────────────────────────────
- @torch.no_grad()
- def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
-           use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
-
- # ─── Inference ────────────────────────────────────────────
  @spaces.GPU
- @torch.no_grad()
  def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
            use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
 
-     # Initialize device and models inside GPU context
      global t5_tok, t5_mod, pipe
      device = torch.device("cuda")
      dtype = torch.float16
 
-     # Load models if not already loaded
      if t5_tok is None:
          t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
          t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
@@ -157,7 +135,7 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
          use_safetensors=True
      ).to(device)
 
-     # Set seed for reproducibility
      if seed != -1:
          torch.manual_seed(seed)
          np.random.seed(seed)
@@ -166,67 +144,62 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
      if scheduler_name in SCHEDULERS:
          pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
 
-     # Get T5 embeddings for semantic understanding - standardize to 77 tokens like CLIP
      t5_ids = t5_tok(
-         prompt,
-         return_tensors="pt",
-         padding="max_length",
-         max_length=77,
-         truncation=True
      ).input_ids.to(device)
      t5_seq = t5_mod(t5_ids).last_hidden_state
 
-     # Get proper SDXL CLIP embeddings
      clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
 
-     # Debug shapes
-     print(f"T5 seq shape: {t5_seq.shape}")
-     print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
-     print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
-
-     # Load adapters
-     adapter_l = load_adapter(repo_l, adapter_l_file, config_l).to(device) if adapter_l_file else None
-     adapter_g = load_adapter(repo_g, adapter_g_file, config_g).to(device) if adapter_g_file else None
 
      # Apply CLIP-L adapter
      if adapter_l is not None:
-         anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
          gate_l_scaled = gate_l * gate_prob
          delta_l_final = delta_l * strength * gate_l_scaled
-         clip_l_mod = clip_embeds["clip_l"] + delta_l_final
          if use_anchor:
-             clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
          if noise > 0:
              clip_l_mod += torch.randn_like(clip_l_mod) * noise
      else:
          clip_l_mod = clip_embeds["clip_l"]
          delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
          gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
-         g_pred_l = torch.tensor(0.0)
-         tau_l = torch.tensor(0.0)
 
      # Apply CLIP-G adapter
      if adapter_g is not None:
-         anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
          gate_g_scaled = gate_g * gate_prob
          delta_g_final = delta_g * strength * gate_g_scaled
-         clip_g_mod = clip_embeds["clip_g"] + delta_g_final
          if use_anchor:
-             clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
          if noise > 0:
              clip_g_mod += torch.randn_like(clip_g_mod) * noise
      else:
          clip_g_mod = clip_embeds["clip_g"]
          delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
          gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
-         g_pred_g = torch.tensor(0.0)
-         tau_g = torch.tensor(0.0)
 
-     # Combine embeddings in SDXL format: [CLIP-L(768) + CLIP-G(1280)] = 2048
-     prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
-     neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)
 
-     # Generate image with proper SDXL parameters
      image = pipe(
          prompt_embeds=prompt_embeds,
          pooled_prompt_embeds=clip_embeds["pooled"],
@@ -236,69 +209,72 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
          guidance_scale=cfg_scale,
          width=width,
          height=height,
-         num_images_per_prompt=1,  # Explicitly set this
          generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
      ).images[0]
 
-     return (
-         image,
-         plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
-         plot_heat(gate_l_scaled.squeeze().cpu().numpy().mean(axis=-1), "Gate CLIP-L"),
-         plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
-         plot_heat(gate_g_scaled.squeeze().cpu().numpy().mean(axis=-1), "Gate CLIP-G"),
-         f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
-         f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
-     )
 
  # ─── Gradio Interface ─────────────────────────────────────────
- with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
-     gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement")
-     gr.Markdown("Enhance SDXL generation by using T5 semantic understanding to modify CLIP embeddings")
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             # Prompts
-             with gr.Group():
-                 gr.Markdown("### Prompts")
                  prompt = gr.Textbox(
-                     label="Prompt",
                      value="a futuristic control station with holographic displays",
-                     lines=3
                  )
                  negative_prompt = gr.Textbox(
                      label="Negative Prompt",
                      value="blurry, low quality, distorted",
-                     lines=2
                  )
-
-             # Adapters
-             with gr.Group():
-                 gr.Markdown("### Adapters")
                  adapter_l = gr.Dropdown(
-                     choices=["None"] + clip_l_opts,
                      label="CLIP-L (768d) Adapter",
-                     value="None"
                  )
                  adapter_g = gr.Dropdown(
-                     choices=["None"] + clip_g_opts,
-                     label="CLIP-G (1280d) Adapter",
-                     value="None"
                  )
-
-             # Adapter Controls
-             with gr.Group():
-                 gr.Markdown("### Adapter Controls")
                  strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
                  noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
                  gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
-                 use_anchor = gr.Checkbox(label="Use Anchor", value=True)
-
-             # Generation Settings
-             with gr.Group():
-                 gr.Markdown("### Generation Settings")
                  with gr.Row():
-                     steps = gr.Slider(1, 100, value=25, step=1, label="Steps")
-                     cfg_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="CFG Scale")
 
                  scheduler_name = gr.Dropdown(
                      choices=list(SCHEDULERS.keys()),
@@ -311,57 +287,48 @@ with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
                  height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
 
                  seed = gr.Number(value=-1, label="Seed (-1 for random)")
 
-                 run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
-
-         with gr.Column(scale=1):
-             # Output
-             with gr.Group():
-                 gr.Markdown("### Generated Image")
-                 out_img = gr.Image(label="Result", height=400)
-
-             # Visualizations
-             with gr.Group():
-                 gr.Markdown("### Adapter Visualizations")
                  with gr.Row():
-                     delta_l = gr.Image(label="Δ CLIP-L", height=200)
-                     gate_l = gr.Image(label="Gate CLIP-L", height=200)
                  with gr.Row():
-                     delta_g = gr.Image(label="Δ CLIP-G", height=200)
-                     gate_g = gr.Image(label="Gate CLIP-G", height=200)
-
-             # Stats
-             with gr.Group():
-                 gr.Markdown("### Adapter Statistics")
-                 stats_l = gr.Textbox(label="CLIP-L Stats", interactive=False)
-                 stats_g = gr.Textbox(label="CLIP-G Stats", interactive=False)
-
-     # Event handlers
-     def process_adapters(adapter_l_val, adapter_g_val):
-         # Convert "None" back to None for processing
-         adapter_l_processed = None if adapter_l_val == "None" else adapter_l_val
-         adapter_g_processed = None if adapter_g_val == "None" else adapter_g_val
-         return adapter_l_processed, adapter_g_processed
-
-     def run_inference(*args):
-         # Process adapter selections
-         adapter_l_processed, adapter_g_processed = process_adapters(args[2], args[3])
 
-         # Call inference with processed adapters
-         new_args = list(args)
-         new_args[2] = adapter_l_processed
-         new_args[3] = adapter_g_processed
 
-         return infer(*new_args)
 
-     run_btn.click(
-         fn=run_inference,
-         inputs=[
-             prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob,
-             use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
-         ],
-         outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
-     )
 
  if __name__ == "__main__":
      demo.launch()
 
app.py (after changes)
  from two_stream_shunt_adapter import TwoStreamShuntAdapter
  from configs import T5_SHUNT_REPOS
 
+ # ─── Global Variables ─────────────────────────────────────────
  t5_tok = None
  t5_mod = None
  pipe = None
 
  config_l = T5_SHUNT_REPOS["clip_l"]["config"]
  config_g = T5_SHUNT_REPOS["clip_g"]["config"]
 
+ # ─── Helper Functions ─────────────────────────────────────────
+ def load_adapter(repo, filename, config, device):
+     from safetensors.torch import safe_open
      path = hf_hub_download(repo_id=repo, filename=filename)
 
      model = TwoStreamShuntAdapter(config).eval()
 
      for key in f.keys():
          tensors[key] = f.get_tensor(key)
      model.load_state_dict(tensors)
+     return model.to(device)
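Note: the refactored loader takes the target device explicitly instead of relying on a module-level global. A minimal usage sketch follows; the adapter filename is hypothetical, since real filenames come from the clip_l_opts / clip_g_opts lists defined earlier in app.py (not shown in this diff):

    import torch

    # Illustrative call only; "example_shunt.safetensors" is a placeholder filename
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    adapter_l = load_adapter(repo_l, "example_shunt.safetensors", config_l, device)
    print(next(adapter_l.parameters()).device)  # confirms the adapter landed on `device`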
 
  def plot_heat(mat, title):
+     """Create heatmap visualization with proper shape handling"""
      import io
+
+     # Ensure we have a 2D array for visualization
+     if len(mat.shape) == 1:
+         mat = mat.reshape(1, -1)
+     elif len(mat.shape) == 3:
+         mat = mat.mean(axis=0)
+     elif len(mat.shape) > 3:
+         mat = mat.reshape(-1, mat.shape[-1])
+
+     fig, ax = plt.subplots(figsize=(8, 4), dpi=100)
+     im = ax.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
+     ax.set_title(title, fontsize=12, fontweight='bold')
+     ax.set_xlabel("Token Position")
+     ax.set_ylabel("Feature Dimension")
+     plt.colorbar(im, ax=ax, shrink=0.8)
+
      buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches='tight', dpi=100)
      buf.seek(0)
+     pil_image = Image.open(buf)
      plt.close(fig)
+     return pil_image
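plot_heat now coerces whatever infer() hands it into a 2D array before plotting, and returns a PIL image that gr.Image can display directly. A quick sketch of the shapes it is expected to handle, using random data (shapes here mirror what infer() passes for a single 77-token prompt; this is an assumption about typical inputs, not part of the diff):

    import numpy as np

    delta = np.random.randn(77, 768)   # 2D delta map: plotted as-is
    gate  = np.random.randn(77, 1)     # per-token gate means (keepdims=True in infer)
    vec   = np.random.randn(77)        # 1D input: reshaped to (1, 77)

    for m in (delta, gate, vec):
        img = plot_heat(m, "demo")     # each call returns a PIL.Image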
 
+ def encode_sdxl_prompt(pipe, prompt, negative_prompt, device):
+     """Generate CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
 
      # Tokenize for both encoders
      tokens_l = pipe.tokenizer(
+         prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
      ).input_ids.to(device)
 
      tokens_g = pipe.tokenizer_2(
+         prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
      ).input_ids.to(device)
 
      neg_tokens_l = pipe.tokenizer(
+         negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
      ).input_ids.to(device)
 
      neg_tokens_g = pipe.tokenizer_2(
+         negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
      ).input_ids.to(device)
 
      with torch.no_grad():
+         # CLIP-L: [0] = sequence, [1] = pooled
          clip_l_embeds = pipe.text_encoder(tokens_l)[0]
          neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]
 
+         # CLIP-G: [0] = pooled, [1] = sequence (different from CLIP-L!)
          clip_g_output = pipe.text_encoder_2(tokens_g)
          clip_g_embeds = clip_g_output[1]   # sequence embeddings
+         pooled_embeds = clip_g_output[0]   # pooled embeddings
 
          neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
+         neg_clip_g_embeds = neg_clip_g_output[1]
+         neg_pooled_embeds = neg_clip_g_output[0]
 
      return {
          "clip_l": clip_l_embeds,
 
          "neg_pooled": neg_pooled_embeds
      }
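For reference, the returned tensors follow the standard SDXL text-encoder dimensions. A sketch of the expected shapes for a single prompt, assuming the pipeline is already loaded, batch size 1, and the 77-token padding used above:

    embeds = encode_sdxl_prompt(pipe, "a cat", "blurry", torch.device("cuda"))
    assert embeds["clip_l"].shape == (1, 77, 768)    # CLIP-L sequence embeddings
    assert embeds["clip_g"].shape == (1, 77, 1280)   # CLIP-G sequence embeddings
    assert embeds["pooled"].shape == (1, 1280)       # pooled CLIP-G, consumed by SDXL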
 
+ # ─── Main Inference Function ──────────────────────────────────
  @spaces.GPU
  def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
            use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
 
      global t5_tok, t5_mod, pipe
      device = torch.device("cuda")
      dtype = torch.float16
 
+     # Initialize models
      if t5_tok is None:
          t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
          t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
 
          use_safetensors=True
      ).to(device)
 
+     # Set seed
      if seed != -1:
          torch.manual_seed(seed)
          np.random.seed(seed)
 
      if scheduler_name in SCHEDULERS:
          pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
 
+     # Get T5 embeddings
      t5_ids = t5_tok(
+         prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True
      ).input_ids.to(device)
      t5_seq = t5_mod(t5_ids).last_hidden_state
 
+     # Get CLIP embeddings
      clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
 
+     # Load and apply adapters
+     adapter_l = load_adapter(repo_l, adapter_l_file, config_l, device) if adapter_l_file else None
+     adapter_g = load_adapter(repo_g, adapter_g_file, config_g, device) if adapter_g_file else None
 
      # Apply CLIP-L adapter
      if adapter_l is not None:
+         anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(
+             t5_seq.float(), clip_embeds["clip_l"].float()
+         )
          gate_l_scaled = gate_l * gate_prob
          delta_l_final = delta_l * strength * gate_l_scaled
+         clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype)
+
          if use_anchor:
+             clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype)
          if noise > 0:
              clip_l_mod += torch.randn_like(clip_l_mod) * noise
      else:
          clip_l_mod = clip_embeds["clip_l"]
          delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
          gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
+         g_pred_l, tau_l = torch.tensor(0.0), torch.tensor(0.0)
 
      # Apply CLIP-G adapter
      if adapter_g is not None:
+         anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(
+             t5_seq.float(), clip_embeds["clip_g"].float()
+         )
          gate_g_scaled = gate_g * gate_prob
          delta_g_final = delta_g * strength * gate_g_scaled
+         clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype)
+
          if use_anchor:
+             clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype)
          if noise > 0:
              clip_g_mod += torch.randn_like(clip_g_mod) * noise
      else:
          clip_g_mod = clip_embeds["clip_g"]
          delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
          gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
+         g_pred_g, tau_g = torch.tensor(0.0), torch.tensor(0.0)
 
+     # Combine embeddings for SDXL: [CLIP-L(768) + CLIP-G(1280)] = 2048
+     prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
+     neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1)
 
+     # Generate image
      image = pipe(
          prompt_embeds=prompt_embeds,
          pooled_prompt_embeds=clip_embeds["pooled"],
 
          guidance_scale=cfg_scale,
          width=width,
          height=height,
+         num_images_per_prompt=1,
          generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
      ).images[0]
 
+     # Create visualizations
+     delta_l_viz = plot_heat(delta_l_final.squeeze().cpu().numpy(), "CLIP-L Delta Values")
+     gate_l_viz = plot_heat(gate_l_scaled.squeeze().cpu().numpy().mean(axis=-1, keepdims=True), "CLIP-L Gate Activations")
+     delta_g_viz = plot_heat(delta_g_final.squeeze().cpu().numpy(), "CLIP-G Delta Values")
+     gate_g_viz = plot_heat(gate_g_scaled.squeeze().cpu().numpy().mean(axis=-1, keepdims=True), "CLIP-G Gate Activations")
+
+     # Statistics
+     stats_l = f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}"
+     stats_g = f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
+
+     return image, delta_l_viz, gate_l_viz, delta_g_viz, gate_g_viz, stats_l, stats_g
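Stripped of the model details, the per-encoder update above is a gated residual plus an optional anchor blend, with the adapter run in float32 and the result cast back to the pipeline's float16. A standalone sketch of just that arithmetic; the tensor shapes and the [0, 1] gate range are assumptions about the adapter's outputs, not something this diff states:

    import torch

    clip   = torch.randn(1, 77, 768, dtype=torch.float16)  # original CLIP sequence
    delta  = torch.randn(1, 77, 768)                        # adapter delta (float32)
    anchor = torch.randn(1, 77, 768)                        # adapter anchor (float32)
    gate   = torch.rand(1, 77, 1)                           # assumed per-token gate in [0, 1]
    strength, gate_prob, use_anchor = 1.0, 1.0, True

    gate_scaled = gate * gate_prob
    clip_mod = clip + (delta * strength * gate_scaled).half()
    if use_anchor:
        clip_mod = clip_mod * (1 - gate_scaled.half()) + anchor.half() * gate_scaled.half()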
 
  # ─── Gradio Interface ─────────────────────────────────────────
+ def create_interface():
+     with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🧠 SDXL Dual Shunt Adapter")
+         gr.Markdown("*Enhance SDXL generation using T5 semantic understanding to modify CLIP embeddings*")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 # Prompts
+                 gr.Markdown("### 📝 Prompts")
                  prompt = gr.Textbox(
+                     label="Prompt",
                      value="a futuristic control station with holographic displays",
+                     lines=3,
+                     placeholder="Describe what you want to generate..."
                  )
                  negative_prompt = gr.Textbox(
                      label="Negative Prompt",
                      value="blurry, low quality, distorted",
+                     lines=2,
+                     placeholder="Describe what you want to avoid..."
                  )
+
+                 # Adapters
+                 gr.Markdown("### ⚙️ Adapters")
                  adapter_l = gr.Dropdown(
+                     choices=["None"] + clip_l_opts,
                      label="CLIP-L (768d) Adapter",
+                     value="None",
+                     info="Choose adapter for CLIP-L embeddings"
                  )
                  adapter_g = gr.Dropdown(
+                     choices=["None"] + clip_g_opts,
+                     label="CLIP-G (1280d) Adapter",
+                     value="None",
+                     info="Choose adapter for CLIP-G embeddings"
                  )
+
+                 # Controls
+                 gr.Markdown("### 🎛️ Adapter Controls")
                  strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
                  noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
                  gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
+                 use_anchor = gr.Checkbox(label="Use Anchor Points", value=True)
+
+                 # Generation Settings
+                 gr.Markdown("### 🎨 Generation Settings")
                  with gr.Row():
+                     steps = gr.Slider(1, 50, value=25, step=1, label="Steps")
+                     cfg_scale = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="CFG Scale")
 
                  scheduler_name = gr.Dropdown(
                      choices=list(SCHEDULERS.keys()),
 
                  height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
 
                  seed = gr.Number(value=-1, label="Seed (-1 for random)")
+
+                 generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg")
 
+             with gr.Column(scale=1):
+                 # Output
+                 gr.Markdown("### 🖼️ Generated Image")
+                 output_image = gr.Image(label="Result", height=400, show_label=False)
+
+                 # Visualizations
+                 gr.Markdown("### 📊 Adapter Analysis")
                  with gr.Row():
+                     delta_l_img = gr.Image(label="CLIP-L Deltas", height=200)
+                     gate_l_img = gr.Image(label="CLIP-L Gates", height=200)
                  with gr.Row():
+                     delta_g_img = gr.Image(label="CLIP-G Deltas", height=200)
+                     gate_g_img = gr.Image(label="CLIP-G Gates", height=200)
+
+                 # Statistics
+                 gr.Markdown("### 📈 Statistics")
+                 stats_l_text = gr.Textbox(label="CLIP-L Metrics", interactive=False)
+                 stats_g_text = gr.Textbox(label="CLIP-G Metrics", interactive=False)
 
+         # Event handler
+         def run_generation(*args):
+             # Process adapter selections
+             processed_args = list(args)
+             processed_args[2] = None if args[2] == "None" else args[2]  # adapter_l
+             processed_args[3] = None if args[3] == "None" else args[3]  # adapter_g
+             return infer(*processed_args)
 
+         generate_btn.click(
+             fn=run_generation,
+             inputs=[
+                 prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob,
+                 use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
+             ],
+             outputs=[output_image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l_text, stats_g_text]
+         )
 
+     return demo
 
+ # ─── Launch ────────────────────────────────────────────────────
  if __name__ == "__main__":
+     demo = create_interface()
      demo.launch()
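Outside the UI, the same pipeline can be driven directly. A sketch of a programmatic call, assuming infer() can be invoked as a plain function (spaces.GPU passes the call through on a Space) and that CUDA is available, since infer() pins everything to cuda; None is passed for both adapters, which is what the UI's "None" choice maps to:

    # Illustrative only; argument order matches the infer() signature above
    result = infer(
        "a futuristic control station", "blurry, low quality",
        None, None,                       # adapter_l_file, adapter_g_file
        1.0, 0.0, 1.0, True,              # strength, noise, gate_prob, use_anchor
        25, 7.5, list(SCHEDULERS.keys())[0], 1024, 1024, -1,
    )
    image = result[0]                     # followed by four heatmaps and two stats strings
    image.save("output.png")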