AbstractPhil committed on
Commit 1e5ce4d · 1 Parent(s): 7b42604
__pycache__/two_stream_shunt_adapter.cpython-310.pyc CHANGED
Binary files a/__pycache__/two_stream_shunt_adapter.cpython-310.pyc and b/__pycache__/two_stream_shunt_adapter.cpython-310.pyc differ
 
app.py CHANGED
@@ -3,25 +3,35 @@ import gradio as gr
  import numpy as np
  import matplotlib.pyplot as plt
  from transformers import T5Tokenizer, T5EncoderModel
- from diffusers import DiffusionPipeline
- from safetensors.torch import safe_open
  from huggingface_hub import hf_hub_download
  from two_stream_shunt_adapter import TwoStreamShuntAdapter
- from adapter_config import T5_SHUNT_REPOS

  # ─── Device & Model Setup ─────────────────────────────────────
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32

  t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
  t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

- pipe = DiffusionPipeline.from_pretrained(
      "stabilityai/stable-diffusion-xl-base-1.0",
      torch_dtype=dtype,
-     variant="fp16" if dtype == torch.float16 else None
  ).to(device)

  # ─── Adapter Configs ──────────────────────────────────────────
  clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
  clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
@@ -31,8 +41,11 @@ config_l = T5_SHUNT_REPOS["clip_l"]["config"]
  config_g = T5_SHUNT_REPOS["clip_g"]["config"]

  # ─── Loader ───────────────────────────────────────────────────
  def load_adapter(repo, filename, config):
      path = hf_hub_download(repo_id=repo, filename=filename)
      model = TwoStreamShuntAdapter(config).eval()
      tensors = {}
      with safe_open(path, framework="pt", device="cpu") as f:
@@ -42,103 +55,277 @@ def load_adapter(repo, filename, config):
      model.to(device)
      return model

- # ─── Inference ────────────────────────────────────────────────
- @torch.no_grad()
- def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
-     adapter_list = []
-     # Load adapters with config
-     adapter_list.append({
-         "adapter": load_adapter(repo_l, adapter_l_file, config_l),
-         "config": config_l
-     })
-     adapter_list.append({
-         "adapter": load_adapter(repo_g, adapter_g_file, config_g),
-         "config": config_g
-     })
-
-     # Encode prompt via T5
-     t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
-     t5_seq = t5_mod(t5_ids).last_hidden_state  # (B, L, 768)
-
-     # Encode prompt via SDXL normally to get CLIP-L and CLIP-G outputs
-     prompt_embeds, pooled_prompt_embeds = pipe._encode_prompt(
-         prompt=prompt,
-         device=device,
-         num_images_per_prompt=1,
-         do_classifier_free_guidance=False,
-     )
-
-     total_dim = prompt_embeds.shape[-1]
-     cond_tensor = prompt_embeds.clone()
-
-     for adapter_info in adapter_list:
-         adapter_model = adapter_info["adapter"]
-         adapter_config = adapter_info["config"]
-         clip_dim = adapter_config["clip"]["hidden_size"]
-
-         if clip_dim == 768:
-             clip_slice = cond_tensor[:, :, :768]
-             slice_start, slice_end = 0, 768
-         elif clip_dim == 1280:
-             clip_slice = cond_tensor[:, :, 768:2048] if total_dim >= 2048 else cond_tensor[:, :, 768:]
-             slice_start, slice_end = 768, 2048
-         else:
-             continue
-
-         anchor, delta_mean_adapter, log_sigma_adapter, _, _, _, g_pred_adapter, gate_adapter = adapter_model(t5_seq, clip_slice)
-         gate = gate_adapter * gate_prob
-         delta = (delta_mean_adapter + 0.0) * strength * gate

-         if delta.shape[1] != clip_slice.shape[1]:
-             delta = torch.nn.functional.interpolate(
-                 delta.transpose(1, 2),
-                 size=clip_slice.size(1),
-                 mode="nearest"
-             ).transpose(1, 2)

          if use_anchor:
-             clip_slice = clip_slice * (1 - gate) + anchor * gate
-
          if noise > 0:
-             clip_slice = clip_slice + torch.randn_like(clip_slice) * noise
-
-         cond_tensor[:, :, slice_start:slice_end] = (clip_slice + delta).type_as(cond_tensor)
-
-     pooled_embed = cond_tensor.mean(dim=1)

      image = pipe(
-         prompt_embeds=cond_tensor,
-         pooled_prompt_embeds=pooled_embed,
-         negative_prompt_embeds=torch.zeros_like(cond_tensor),
-         negative_pooled_prompt_embeds=torch.zeros_like(pooled_embed),
-         num_inference_steps=20,
-         guidance_scale=5.0
      ).images[0]

-     return image
-
- # ─── Gradio App ───────────────────────────────────────────────
- with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
-     gr.Markdown("# 🧠 Dual Shunt Adapter • SDXL Inference")
-
      with gr.Row():
-         with gr.Column():
-             prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
-             adapter_l = gr.Dropdown(choices=clip_l_opts, label="CLIP-L (768d) Adapter")
-             adapter_g = gr.Dropdown(choices=clip_g_opts, label="CLIP-G (1280d) Adapter")
-             strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
-             noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
-             gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
-             use_anchor = gr.Checkbox(label="Use Anchor", value=True)
-             run_btn = gr.Button("Run")
-
-         with gr.Column():
-             out_img = gr.Image(label="Generated Image")
-
      run_btn.click(
-         fn=infer,
-         inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
-         outputs=out_img
      )

  if __name__ == "__main__":
-     demo.launch(share=True)
 
  import numpy as np
  import matplotlib.pyplot as plt
  from transformers import T5Tokenizer, T5EncoderModel
+ from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
+ from safetensors.torch import load_file
  from huggingface_hub import hf_hub_download
  from two_stream_shunt_adapter import TwoStreamShuntAdapter
+ from configs import T5_SHUNT_REPOS

  # ─── Device & Model Setup ─────────────────────────────────────
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32

+ # T5 Model for semantic understanding
  t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
  t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

+ # SDXL Pipeline with proper text encoders
+ pipe = StableDiffusionXLPipeline.from_pretrained(
      "stabilityai/stable-diffusion-xl-base-1.0",
      torch_dtype=dtype,
+     variant="fp16" if dtype == torch.float16 else None,
+     use_safetensors=True
  ).to(device)

+ # Available schedulers
+ SCHEDULERS = {
+     "DPM++ 2M": DPMSolverMultistepScheduler,
+     "DDIM": DDIMScheduler,
+     "Euler": EulerDiscreteScheduler,
+ }
+
  # ─── Adapter Configs ──────────────────────────────────────────
  clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
  clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]

  config_g = T5_SHUNT_REPOS["clip_g"]["config"]

  # ─── Loader ───────────────────────────────────────────────────
+ from safetensors.torch import safe_open
+
  def load_adapter(repo, filename, config):
      path = hf_hub_download(repo_id=repo, filename=filename)
+
      model = TwoStreamShuntAdapter(config).eval()
      tensors = {}
      with safe_open(path, framework="pt", device="cpu") as f:

      model.to(device)
      return model

+ # ─── Visualization ────────────────────────────────────────────
+ def plot_heat(mat, title):
+     import io
+     fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
+     im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
+     ax.set_title(title)
+     plt.colorbar(im, ax=ax)
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches='tight')
+     buf.seek(0)
+     plt.close(fig)
+     return buf

+ # ─── SDXL Text Encoding ───────────────────────────────────────
+ def encode_sdxl_prompt(prompt, negative_prompt=""):
+     """Generate proper CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
+
+     # Tokenize for both encoders
+     tokens_l = pipe.tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=77,
+         truncation=True,
+         return_tensors="pt"
+     ).input_ids.to(device)
+
+     tokens_g = pipe.tokenizer_2(
+         prompt,
+         padding="max_length",
+         max_length=77,
+         truncation=True,
+         return_tensors="pt"
+     ).input_ids.to(device)
+
+     # Negative prompts
+     neg_tokens_l = pipe.tokenizer(
+         negative_prompt,
+         padding="max_length",
+         max_length=77,
+         truncation=True,
+         return_tensors="pt"
+     ).input_ids.to(device)
+
+     neg_tokens_g = pipe.tokenizer_2(
+         negative_prompt,
+         padding="max_length",
+         max_length=77,
+         truncation=True,
+         return_tensors="pt"
+     ).input_ids.to(device)
+
+     with torch.no_grad():
+         # CLIP-L embeddings (768d)
+         clip_l_embeds = pipe.text_encoder(tokens_l)[0]
+         neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]
+
+         # CLIP-G embeddings (1280d)
+         clip_g_embeds = pipe.text_encoder_2(tokens_g)[0]
+         neg_clip_g_embeds = pipe.text_encoder_2(neg_tokens_g)[0]
+
+         # Pooled embeddings for SDXL
+         pooled_embeds = pipe.text_encoder_2(tokens_g)[1]
+         neg_pooled_embeds = pipe.text_encoder_2(neg_tokens_g)[1]
+
+     return {
+         "clip_l": clip_l_embeds,
+         "clip_g": clip_g_embeds,
+         "neg_clip_l": neg_clip_l_embeds,
+         "neg_clip_g": neg_clip_g_embeds,
+         "pooled": pooled_embeds,
+         "neg_pooled": neg_pooled_embeds
+     }

+ # ─── Inference ────────────────────────────────────────────────
+ @torch.no_grad()
+ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
+           use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
+
+     # Set seed for reproducibility
+     if seed != -1:
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+
+     # Set scheduler
+     if scheduler_name in SCHEDULERS:
+         pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
+
+     # Get T5 embeddings for semantic understanding
+     t5_ids = t5_tok(prompt, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
+     t5_seq = t5_mod(t5_ids).last_hidden_state
+
+     # Get proper SDXL CLIP embeddings
+     clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)
+
+     # Load adapters
+     adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
+     adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None
+
+     # Apply CLIP-L adapter
+     if adapter_l is not None:
+         anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
+         gate_l_scaled = gate_l * gate_prob
+         delta_l_final = delta_l * strength * gate_l_scaled
+         clip_l_mod = clip_embeds["clip_l"] + delta_l_final
          if use_anchor:
+             clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
          if noise > 0:
+             clip_l_mod += torch.randn_like(clip_l_mod) * noise
+     else:
+         clip_l_mod = clip_embeds["clip_l"]
+         delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
+         gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
+         g_pred_l = torch.tensor(0.0)
+         tau_l = torch.tensor(0.0)
+
+     # Apply CLIP-G adapter
+     if adapter_g is not None:
+         anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
+         gate_g_scaled = gate_g * gate_prob
+         delta_g_final = delta_g * strength * gate_g_scaled
+         clip_g_mod = clip_embeds["clip_g"] + delta_g_final
+         if use_anchor:
+             clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
+         if noise > 0:
+             clip_g_mod += torch.randn_like(clip_g_mod) * noise
+     else:
+         clip_g_mod = clip_embeds["clip_g"]
+         delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
+         gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
+         g_pred_g = torch.tensor(0.0)
+         tau_g = torch.tensor(0.0)
+
+     # Combine embeddings in SDXL format: [CLIP-L(768) + CLIP-G(1280)] = 2048
+     prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
+     neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)
+
+     # Generate image with proper SDXL parameters
      image = pipe(
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=clip_embeds["pooled"],
+         negative_prompt_embeds=neg_embeds,
+         negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
+         num_inference_steps=steps,
+         guidance_scale=cfg_scale,
+         width=width,
+         height=height,
+         generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
      ).images[0]
+
+     return (
+         image,
+         plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
+         plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
+         plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
+         plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
+         f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
+         f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
+     )

+ # ─── Gradio Interface ─────────────────────────────────────────
+ with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement")
+     gr.Markdown("Enhance SDXL generation by using T5 semantic understanding to modify CLIP embeddings")
+
      with gr.Row():
+         with gr.Column(scale=1):
+             # Prompts
+             with gr.Group():
+                 gr.Markdown("### Prompts")
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     value="a futuristic control station with holographic displays",
+                     lines=3
+                 )
+                 negative_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     value="blurry, low quality, distorted",
+                     lines=2
+                 )
+
+             # Adapters
+             with gr.Group():
+                 gr.Markdown("### Adapters")
+                 adapter_l = gr.Dropdown(
+                     choices=["None"] + clip_l_opts,
+                     label="CLIP-L (768d) Adapter",
+                     value="None"
+                 )
+                 adapter_g = gr.Dropdown(
+                     choices=["None"] + clip_g_opts,
+                     label="CLIP-G (1280d) Adapter",
+                     value="None"
+                 )
+
+             # Adapter Controls
+             with gr.Group():
+                 gr.Markdown("### Adapter Controls")
+                 strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
+                 noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
+                 gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
+                 use_anchor = gr.Checkbox(label="Use Anchor", value=True)
+
+             # Generation Settings
+             with gr.Group():
+                 gr.Markdown("### Generation Settings")
+                 with gr.Row():
+                     steps = gr.Slider(1, 100, value=25, step=1, label="Steps")
+                     cfg_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="CFG Scale")
+
+                 scheduler_name = gr.Dropdown(
+                     choices=list(SCHEDULERS.keys()),
+                     value="DPM++ 2M",
+                     label="Scheduler"
+                 )
+
+                 with gr.Row():
+                     width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
+                     height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
+
+                 seed = gr.Number(value=-1, label="Seed (-1 for random)")
+
+             run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
+
+         with gr.Column(scale=1):
+             # Output
+             with gr.Group():
+                 gr.Markdown("### Generated Image")
+                 out_img = gr.Image(label="Result", height=400)
+
+             # Visualizations
+             with gr.Group():
+                 gr.Markdown("### Adapter Visualizations")
+                 with gr.Row():
+                     delta_l = gr.Image(label="Δ CLIP-L", height=200)
+                     gate_l = gr.Image(label="Gate CLIP-L", height=200)
+                 with gr.Row():
+                     delta_g = gr.Image(label="Δ CLIP-G", height=200)
+                     gate_g = gr.Image(label="Gate CLIP-G", height=200)
+
+             # Stats
+             with gr.Group():
+                 gr.Markdown("### Adapter Statistics")
+                 stats_l = gr.Textbox(label="CLIP-L Stats", interactive=False)
+                 stats_g = gr.Textbox(label="CLIP-G Stats", interactive=False)
+
+     # Event handlers
+     def process_adapters(adapter_l_val, adapter_g_val):
+         # Convert "None" back to None for processing
+         adapter_l_processed = None if adapter_l_val == "None" else adapter_l_val
+         adapter_g_processed = None if adapter_g_val == "None" else adapter_g_val
+         return adapter_l_processed, adapter_g_processed
+
+     def run_inference(*args):
+         # Process adapter selections
+         adapter_l_processed, adapter_g_processed = process_adapters(args[2], args[3])
+
+         # Call inference with processed adapters
+         new_args = list(args)
+         new_args[2] = adapter_l_processed
+         new_args[3] = adapter_g_processed
+
+         return infer(*new_args)
+
      run_btn.click(
+         fn=run_inference,
+         inputs=[
+             prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob,
+             use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
+         ],
+         outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
      )

  if __name__ == "__main__":
+     demo.launch(share=True)
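
Note on the update rule above: the new infer() nudges each CLIP stream with a gated delta and an optional anchor blend, then concatenates CLIP-L (768) and CLIP-G (1280) into SDXL's 2048-dim prompt embedding. The following is a minimal, self-contained sketch of that per-stream update; apply_shunt and the dummy shapes are illustrative assumptions, not code from this commit.

import torch

def apply_shunt(clip_seq, anchor, delta, gate, strength=1.0, gate_prob=1.0, use_anchor=True, noise=0.0):
    # Scale the gate, add the gated delta, optionally blend toward the anchor, optionally add noise.
    gate_scaled = gate * gate_prob
    modified = clip_seq + delta * strength * gate_scaled
    if use_anchor:
        modified = modified * (1 - gate_scaled) + anchor * gate_scaled
    if noise > 0:
        modified = modified + torch.randn_like(modified) * noise
    return modified

# SDXL expects both text streams concatenated on the last dim: 768 (CLIP-L) + 1280 (CLIP-G) = 2048.
clip_l = torch.randn(1, 77, 768)     # illustrative shapes
clip_g = torch.randn(1, 77, 1280)
prompt_embeds = torch.cat([clip_l, clip_g], dim=-1)   # (1, 77, 2048)
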
two_stream_shunt_adapter.py CHANGED
@@ -1,331 +1,123 @@
  import torch
+ import torch.nn as nn
+ import torch.nn.functional as F

+ # ─── Residual Pocket Block ───────────────────────────────────
+ class BottleneckResBlock(nn.Module):
+     def __init__(self, dim, kernel=3, dropout=0.1):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)
+         self.proj = nn.Sequential(
+             nn.Linear(dim, dim * 2),
+             nn.GELU(),
+             nn.Linear(dim * 2, dim),
+             nn.Dropout(dropout)
+         )

+     def forward(self, x):
+         residual = x
+         x = self.norm(x)
+         x = x.transpose(1, 2)
+         x = self.conv(x).transpose(1, 2)
+         return residual + self.proj(x)

+ # ─── Two Stream Shunt Adapter ──────────────────────────────────────
+ class TwoStreamShuntAdapter(nn.Module):
+     def __init__(self, config: dict):
+         super().__init__()
+         self.config = config
+         self.t5_dim = config["t5"]["hidden_size"]
+         self.clip_dim = config["clip"]["hidden_size"]
+         self.bneck = config["bottleneck"]
+         self.heads = config["heads"]
+         self.tau_init = config["tau_init"]
+         self.max_guidance = config["max_guidance"]
+
+         use_norm = config.get("layer_norm", True)
+         use_do = config.get("use_dropout", True)
+         do_p = config.get("dropout", 0.1)
+         proj_depth = config.get("proj_layers", 2)
+
+         def build_projection(input_dim, output_dim):
+             layers = []
+             last_dim = input_dim
+             if use_norm:
+                 layers.append(nn.LayerNorm(last_dim))
+             for i in range(proj_depth):
+                 next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)
+                 layers.append(nn.Linear(last_dim, next_dim))
+                 layers.append(nn.GELU())
+                 if use_do:
+                     layers.append(nn.Dropout(do_p))
+                 last_dim = next_dim
+             layers.append(nn.Linear(last_dim, output_dim))
+             return nn.Sequential(*layers)
+
+         # Projections
+         self.proj_t5 = build_projection(self.t5_dim, self.bneck)
+         self.proj_clip = build_projection(self.clip_dim, self.bneck)
+
+         # Attention
+         self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
+         self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
+         self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))
+
+         # Residual Pocket
+         self.pocket_blocks = nn.Sequential(
+             BottleneckResBlock(self.bneck, dropout=do_p),
+             BottleneckResBlock(self.bneck, dropout=do_p)
+         )
+
+         # Fuse
+         self.fuse = nn.Sequential(
+             nn.LayerNorm(2 * self.bneck),
+             nn.Linear(2 * self.bneck, self.bneck * 2),
+             nn.GELU(),
+             nn.Linear(self.bneck * 2, self.bneck)
+         )
+
+         # Output Projections
+         self.anchor_proj = build_projection(self.bneck, self.clip_dim)
+         self.delta_proj = build_projection(self.bneck, self.clip_dim)
+         self.logsig_proj = build_projection(self.bneck, self.clip_dim)
+
+         self.gate_proj = nn.Sequential(
+             nn.LayerNorm(self.bneck),
+             nn.Linear(self.bneck, self.bneck),
+             nn.GELU(),
+             nn.Linear(self.bneck, 1),
+             nn.Tanh(),
+             nn.Sigmoid()
+         )
+
+         self.guidance_proj = nn.Sequential(
+             nn.LayerNorm(self.bneck),
+             nn.Linear(self.bneck, 1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
+         if self.config.get("assert_input_dims", True):
+             assert t5_seq.size(-1) == self.t5_dim
+             assert clip_seq.size(-1) == self.clip_dim
+
+         t5_b = self.proj_t5(t5_seq)
+         clip_b = self.proj_clip(clip_seq)
+
+         t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
+         c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)
+
+         pocket = self.pocket_blocks(t2c)
+
+         pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)
+         h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))
+
+         anchor = self.anchor_proj(h)
+         delta = self.delta_proj(h) * self.gate_proj(h)
+         log_sigma = self.logsig_proj(h)
+
+         g_tok = self.guidance_proj(h).squeeze(-1)
+         g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance
+
+         return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, self.gate_proj(h)
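
A minimal usage sketch for the new module, assuming a config dict with the keys read in __init__ above; the concrete values (bottleneck 256, 8 heads, tau_init 0.1, max_guidance 10.0) are illustrative and not taken from the repo's adapter configs.

import torch
from two_stream_shunt_adapter import TwoStreamShuntAdapter

config = {
    "t5": {"hidden_size": 768},     # flan-t5-base hidden size
    "clip": {"hidden_size": 768},   # CLIP-L stream; 1280 for CLIP-G
    "bottleneck": 256,              # assumed value
    "heads": 8,                     # assumed value
    "tau_init": 0.1,                # assumed value
    "max_guidance": 10.0,           # assumed value
}

adapter = TwoStreamShuntAdapter(config).eval()
t5_seq = torch.randn(1, 77, 768)    # dummy T5 sequence
clip_seq = torch.randn(1, 77, 768)  # dummy CLIP-L sequence
with torch.no_grad():
    anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)
print(anchor.shape, delta.shape, gate.shape)  # (1, 77, 768), (1, 77, 768), (1, 77, 1)
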