AbstractPhil committed on
Commit 7b42604 · 1 Parent(s): 75ce7bd
Files changed (2)
  1. app.py +68 -75
  2. two_stream_shunt_adapter.py +318 -110
app.py CHANGED
@@ -4,10 +4,10 @@ import numpy as np
 import matplotlib.pyplot as plt
 from transformers import T5Tokenizer, T5EncoderModel
 from diffusers import DiffusionPipeline
-from safetensors.torch import load_file
+from safetensors.torch import safe_open
 from huggingface_hub import hf_hub_download
 from two_stream_shunt_adapter import TwoStreamShuntAdapter
-from configs import T5_SHUNT_REPOS
+from adapter_config import T5_SHUNT_REPOS

 # ─── Device & Model Setup ─────────────────────────────────────
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -31,12 +31,8 @@ config_l = T5_SHUNT_REPOS["clip_l"]["config"]
 config_g = T5_SHUNT_REPOS["clip_g"]["config"]

 # ─── Loader ───────────────────────────────────────────────────
-from safetensors.torch import safe_open
-
 def load_adapter(repo, filename, config):
     path = hf_hub_download(repo_id=repo, filename=filename)
-
-    # Fallback-safe loading for ZeroGPU
     model = TwoStreamShuntAdapter(config).eval()
     tensors = {}
     with safe_open(path, framework="pt", device="cpu") as f:
@@ -46,76 +42,79 @@ def load_adapter(repo, filename, config):
     model.to(device)
     return model

-
-# ─── Visualization ────────────────────────────────────────────
-def plot_heat(mat, title):
-    import io
-    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
-    im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
-    ax.set_title(title)
-    plt.colorbar(im, ax=ax)
-    buf = io.BytesIO()
-    plt.savefig(buf, format="png", bbox_inches='tight')
-    buf.seek(0)
-    return buf
-
 # ─── Inference ────────────────────────────────────────────────
 @torch.no_grad()
 def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
+    adapter_list = []
+    # Load adapters with config
+    adapter_list.append({
+        "adapter": load_adapter(repo_l, adapter_l_file, config_l),
+        "config": config_l
+    })
+    adapter_list.append({
+        "adapter": load_adapter(repo_g, adapter_g_file, config_g),
+        "config": config_g
+    })
+
+    # Encode prompt via T5
     t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
-    t5_seq = t5_mod(t5_ids).last_hidden_state
-
-    adapter_l = load_adapter(repo_l, adapter_l_file, config_l)
-    adapter_g = load_adapter(repo_g, adapter_g_file, config_g)
-
-    clip_l_in = torch.randn(t5_seq.shape[0], 77, 768).to(device)
-    clip_g_in = torch.randn(t5_seq.shape[0], 77, 1280).to(device)
-
-    anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_l_in)
-    gate_l_scaled = gate_l * gate_prob
-    delta_l_final = delta_l * strength * gate_l_scaled
-    clip_l_mod = clip_l_in + delta_l_final
-    if use_anchor:
-        clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
-    if noise > 0:
-        clip_l_mod += torch.randn_like(clip_l_mod) * noise
-
-    anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_g_in)
-    gate_g_scaled = gate_g * gate_prob
-    delta_g_final = delta_g * strength * gate_g_scaled
-    clip_g_mod = clip_g_in + delta_g_final
-    if use_anchor:
-        clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
-    if noise > 0:
-        clip_g_mod += torch.randn_like(clip_g_mod) * noise
-
-    # Combine embeddings
-    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
-    neg_embeds = torch.zeros_like(prompt_embeds)
-
-    # Compute pooled embeds (mean pooling as default fallback)
-    pooled_prompt_embeds = prompt_embeds.mean(dim=1)
-    pooled_neg_embeds = neg_embeds.mean(dim=1)
-
-    # SDXL generation with required pooled embeddings
+    t5_seq = t5_mod(t5_ids).last_hidden_state  # (B, L, 768)
+
+    # Encode prompt via SDXL normally to get CLIP-L and CLIP-G outputs
+    prompt_embeds, pooled_prompt_embeds = pipe._encode_prompt(
+        prompt=prompt,
+        device=device,
+        num_images_per_prompt=1,
+        do_classifier_free_guidance=False,
+    )
+
+    total_dim = prompt_embeds.shape[-1]
+    cond_tensor = prompt_embeds.clone()
+
+    for adapter_info in adapter_list:
+        adapter_model = adapter_info["adapter"]
+        adapter_config = adapter_info["config"]
+        clip_dim = adapter_config["clip"]["hidden_size"]
+
+        if clip_dim == 768:
+            clip_slice = cond_tensor[:, :, :768]
+            slice_start, slice_end = 0, 768
+        elif clip_dim == 1280:
+            clip_slice = cond_tensor[:, :, 768:2048] if total_dim >= 2048 else cond_tensor[:, :, 768:]
+            slice_start, slice_end = 768, 2048
+        else:
+            continue
+
+        anchor, delta_mean_adapter, log_sigma_adapter, _, _, _, g_pred_adapter, gate_adapter = adapter_model(t5_seq, clip_slice)
+        gate = gate_adapter * gate_prob
+        delta = (delta_mean_adapter + 0.0) * strength * gate
+
+        if delta.shape[1] != clip_slice.shape[1]:
+            delta = torch.nn.functional.interpolate(
+                delta.transpose(1, 2),
+                size=clip_slice.size(1),
+                mode="nearest"
+            ).transpose(1, 2)
+
+        if use_anchor:
+            clip_slice = clip_slice * (1 - gate) + anchor * gate
+
+        if noise > 0:
+            clip_slice = clip_slice + torch.randn_like(clip_slice) * noise
+
+        cond_tensor[:, :, slice_start:slice_end] = (clip_slice + delta).type_as(cond_tensor)
+
+    pooled_embed = cond_tensor.mean(dim=1)
     image = pipe(
-        prompt_embeds=prompt_embeds,
-        pooled_prompt_embeds=pooled_prompt_embeds,
-        negative_prompt_embeds=neg_embeds,
-        negative_pooled_prompt_embeds=pooled_neg_embeds,
+        prompt_embeds=cond_tensor,
+        pooled_prompt_embeds=pooled_embed,
+        negative_prompt_embeds=torch.zeros_like(cond_tensor),
+        negative_pooled_prompt_embeds=torch.zeros_like(pooled_embed),
         num_inference_steps=20,
         guidance_scale=5.0
     ).images[0]

-    return (
-        image,
-        plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
-        plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
-        plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
-        plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
-        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
-        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
-    )
+    return image

 # ─── Gradio App ───────────────────────────────────────────────
 with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
@@ -134,18 +133,12 @@ with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:

     with gr.Column():
         out_img = gr.Image(label="Generated Image")
-        delta_l = gr.Image(label="Δ CLIP-L")
-        gate_l = gr.Image(label="Gate CLIP-L")
-        delta_g = gr.Image(label="Δ CLIP-G")
-        gate_g = gr.Image(label="Gate CLIP-G")
-        stats_l = gr.Textbox(label="CLIP-L Stats")
-        stats_g = gr.Textbox(label="CLIP-G Stats")

     run_btn.click(
         fn=infer,
         inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
-        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
+        outputs=out_img
     )

 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
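
Note on the new app.py flow above: instead of feeding random CLIP tensors into the adapters, infer() now clones the SDXL prompt embedding, treats dims 0:768 as the CLIP-L slice and dims 768:2048 as the CLIP-G slice, and writes a gated, strength-scaled delta back into each slice. A minimal stand-alone sketch of that slice-and-blend step (illustration only, not part of the commit; the tensors below are stand-ins for the real adapter and pipeline outputs):

import torch

# Stand-in shapes: the SDXL prompt embedding is (batch, 77, 2048) = CLIP-L(768) + CLIP-G(1280).
cond_tensor = torch.randn(1, 77, 2048)
anchor = torch.randn(1, 77, 768)      # adapter anchor for the CLIP-L slice
delta_mean = torch.randn(1, 77, 768)  # adapter delta for the CLIP-L slice
gate = torch.rand(1, 77, 1)           # per-token gate in [0, 1]
strength, gate_prob, use_anchor, noise = 1.0, 1.0, True, 0.0

# Same blend the committed adapter loop applies per slice:
gate = gate * gate_prob
delta = delta_mean * strength * gate
clip_slice = cond_tensor[:, :, 0:768]
if use_anchor:
    clip_slice = clip_slice * (1 - gate) + anchor * gate
if noise > 0:
    clip_slice = clip_slice + torch.randn_like(clip_slice) * noise
cond_tensor[:, :, 0:768] = (clip_slice + delta).type_as(cond_tensor)
pooled = cond_tensor.mean(dim=1)      # mean-pooled embedding handed to the pipeline

With gate_prob at 0 (and no injected noise) the update vanishes and the original embedding passes through unchanged, which is how the Gradio sliders expose the adapter.
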
two_stream_shunt_adapter.py CHANGED
@@ -1,123 +1,331 @@
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+import gradio as gr
+import numpy as np
+import matplotlib.pyplot as plt
+from transformers import T5Tokenizer, T5EncoderModel
+from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
+from safetensors.torch import load_file
+from huggingface_hub import hf_hub_download
+from two_stream_shunt_adapter import TwoStreamShuntAdapter
+from configs import T5_SHUNT_REPOS

-# ─── Residual Pocket Block ───────────────────────────────────
-class BottleneckResBlock(nn.Module):
-    def __init__(self, dim, kernel=3, dropout=0.1):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)
-        self.proj = nn.Sequential(
-            nn.Linear(dim, dim * 2),
-            nn.GELU(),
-            nn.Linear(dim * 2, dim),
-            nn.Dropout(dropout)
-        )
+# ─── Device & Model Setup ─────────────────────────────────────
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32

-    def forward(self, x):
-        residual = x
-        x = self.norm(x)
-        x = x.transpose(1, 2)
-        x = self.conv(x).transpose(1, 2)
-        return residual + self.proj(x)
+# T5 Model for semantic understanding
+t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
+t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

-# ─── Two Stream Shunt Adapter ──────────────────────────────────────
-class TwoStreamShuntAdapter(nn.Module):
-    def __init__(self, config: dict):
-        super().__init__()
-        self.config = config
-        self.t5_dim = config["t5"]["hidden_size"]
-        self.clip_dim = config["clip"]["hidden_size"]
-        self.bneck = config["bottleneck"]
-        self.heads = config["heads"]
-        self.tau_init = config["tau_init"]
-        self.max_guidance = config["max_guidance"]
+# SDXL Pipeline with proper text encoders
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=dtype,
+    variant="fp16" if dtype == torch.float16 else None,
+    use_safetensors=True
+).to(device)

-        use_norm = config.get("layer_norm", True)
-        use_do = config.get("use_dropout", True)
-        do_p = config.get("dropout", 0.1)
-        proj_depth = config.get("proj_layers", 2)
+# Available schedulers
+SCHEDULERS = {
+    "DPM++ 2M": DPMSolverMultistepScheduler,
+    "DDIM": DDIMScheduler,
+    "Euler": EulerDiscreteScheduler,
+}

-        def build_projection(input_dim, output_dim):
-            layers = []
-            last_dim = input_dim
-            if use_norm:
-                layers.append(nn.LayerNorm(last_dim))
-            for i in range(proj_depth):
-                next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)
-                layers.append(nn.Linear(last_dim, next_dim))
-                layers.append(nn.GELU())
-                if use_do:
-                    layers.append(nn.Dropout(do_p))
-                last_dim = next_dim
-            layers.append(nn.Linear(last_dim, output_dim))
-            return nn.Sequential(*layers)
+# ─── Adapter Configs ──────────────────────────────────────────
+clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
+clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
+repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
+repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
+config_l = T5_SHUNT_REPOS["clip_l"]["config"]
+config_g = T5_SHUNT_REPOS["clip_g"]["config"]

-        # Projections
-        self.proj_t5 = build_projection(self.t5_dim, self.bneck)
-        self.proj_clip = build_projection(self.clip_dim, self.bneck)
+# ─── Loader ───────────────────────────────────────────────────
+from safetensors.torch import safe_open

-        # Attention
-        self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
-        self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
-        self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))
+def load_adapter(repo, filename, config):
+    path = hf_hub_download(repo_id=repo, filename=filename)
+
+    model = TwoStreamShuntAdapter(config).eval()
+    tensors = {}
+    with safe_open(path, framework="pt", device="cpu") as f:
+        for key in f.keys():
+            tensors[key] = f.get_tensor(key)
+    model.load_state_dict(tensors)
+    model.to(device)
+    return model

-        # Residual Pocket
-        self.pocket_blocks = nn.Sequential(
-            BottleneckResBlock(self.bneck, dropout=do_p),
-            BottleneckResBlock(self.bneck, dropout=do_p)
-        )
+# ─── Visualization ────────────────────────────────────────────
+def plot_heat(mat, title):
+    import io
+    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
+    im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
+    ax.set_title(title)
+    plt.colorbar(im, ax=ax)
+    buf = io.BytesIO()
+    plt.savefig(buf, format="png", bbox_inches='tight')
+    buf.seek(0)
+    plt.close(fig)
+    return buf

-        # Fuse
-        self.fuse = nn.Sequential(
-            nn.LayerNorm(2 * self.bneck),
-            nn.Linear(2 * self.bneck, self.bneck * 2),
-            nn.GELU(),
-            nn.Linear(self.bneck * 2, self.bneck)
-        )
+# ─── SDXL Text Encoding ───────────────────────────────────────
+def encode_sdxl_prompt(prompt, negative_prompt=""):
+    """Generate proper CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
+
+    # Tokenize for both encoders
+    tokens_l = pipe.tokenizer(
+        prompt,
+        padding="max_length",
+        max_length=77,
+        truncation=True,
+        return_tensors="pt"
+    ).input_ids.to(device)
+
+    tokens_g = pipe.tokenizer_2(
+        prompt,
+        padding="max_length",
+        max_length=77,
+        truncation=True,
+        return_tensors="pt"
+    ).input_ids.to(device)
+
+    # Negative prompts
+    neg_tokens_l = pipe.tokenizer(
+        negative_prompt,
+        padding="max_length",
+        max_length=77,
+        truncation=True,
+        return_tensors="pt"
+    ).input_ids.to(device)
+
+    neg_tokens_g = pipe.tokenizer_2(
+        negative_prompt,
+        padding="max_length",
+        max_length=77,
+        truncation=True,
+        return_tensors="pt"
+    ).input_ids.to(device)
+
+    with torch.no_grad():
+        # CLIP-L embeddings (768d)
+        clip_l_embeds = pipe.text_encoder(tokens_l)[0]
+        neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]
+
+        # CLIP-G embeddings (1280d)
+        clip_g_embeds = pipe.text_encoder_2(tokens_g)[0]
+        neg_clip_g_embeds = pipe.text_encoder_2(neg_tokens_g)[0]
+
+        # Pooled embeddings for SDXL
+        pooled_embeds = pipe.text_encoder_2(tokens_g)[1]
+        neg_pooled_embeds = pipe.text_encoder_2(neg_tokens_g)[1]
+
+    return {
+        "clip_l": clip_l_embeds,
+        "clip_g": clip_g_embeds,
+        "neg_clip_l": neg_clip_l_embeds,
+        "neg_clip_g": neg_clip_g_embeds,
+        "pooled": pooled_embeds,
+        "neg_pooled": neg_pooled_embeds
+    }

-        # Output Projections
-        self.anchor_proj = build_projection(self.bneck, self.clip_dim)
-        self.delta_proj = build_projection(self.bneck, self.clip_dim)
-        self.logsig_proj = build_projection(self.bneck, self.clip_dim)
+# ─── Inference ────────────────────────────────────────────────
+@torch.no_grad()
+def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
+          use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
+
+    # Set seed for reproducibility
+    if seed != -1:
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+
+    # Set scheduler
+    if scheduler_name in SCHEDULERS:
+        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
+
+    # Get T5 embeddings for semantic understanding
+    t5_ids = t5_tok(prompt, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
+    t5_seq = t5_mod(t5_ids).last_hidden_state
+
+    # Get proper SDXL CLIP embeddings
+    clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)
+
+    # Load adapters
+    adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
+    adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None
+
+    # Apply CLIP-L adapter
+    if adapter_l is not None:
+        anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
+        gate_l_scaled = gate_l * gate_prob
+        delta_l_final = delta_l * strength * gate_l_scaled
+        clip_l_mod = clip_embeds["clip_l"] + delta_l_final
+        if use_anchor:
+            clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
+        if noise > 0:
+            clip_l_mod += torch.randn_like(clip_l_mod) * noise
+    else:
+        clip_l_mod = clip_embeds["clip_l"]
+        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
+        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
+        g_pred_l = torch.tensor(0.0)
+        tau_l = torch.tensor(0.0)
+
+    # Apply CLIP-G adapter
+    if adapter_g is not None:
+        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
+        gate_g_scaled = gate_g * gate_prob
+        delta_g_final = delta_g * strength * gate_g_scaled
+        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
+        if use_anchor:
+            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
+        if noise > 0:
+            clip_g_mod += torch.randn_like(clip_g_mod) * noise
+    else:
+        clip_g_mod = clip_embeds["clip_g"]
+        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
+        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
+        g_pred_g = torch.tensor(0.0)
+        tau_g = torch.tensor(0.0)
+
+    # Combine embeddings in SDXL format: [CLIP-L(768) + CLIP-G(1280)] = 2048
+    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
+    neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)
+
+    # Generate image with proper SDXL parameters
+    image = pipe(
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=clip_embeds["pooled"],
+        negative_prompt_embeds=neg_embeds,
+        negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
+        num_inference_steps=steps,
+        guidance_scale=cfg_scale,
+        width=width,
+        height=height,
+        generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
+    ).images[0]
+
+    return (
+        image,
+        plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
+        plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
+        plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
+        plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
+        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
+        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
+    )

-        self.gate_proj = nn.Sequential(
-            nn.LayerNorm(self.bneck),
-            nn.Linear(self.bneck, self.bneck),
-            nn.GELU(),
-            nn.Linear(self.bneck, 1),
-            nn.Tanh(),
-            nn.Sigmoid()
-        )
+# ─── Gradio Interface ─────────────────────────────────────────
+with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement")
+    gr.Markdown("Enhance SDXL generation by using T5 semantic understanding to modify CLIP embeddings")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            # Prompts
+            with gr.Group():
+                gr.Markdown("### Prompts")
+                prompt = gr.Textbox(
+                    label="Prompt",
+                    value="a futuristic control station with holographic displays",
+                    lines=3
+                )
+                negative_prompt = gr.Textbox(
+                    label="Negative Prompt",
+                    value="blurry, low quality, distorted",
+                    lines=2
+                )
+
+            # Adapters
+            with gr.Group():
+                gr.Markdown("### Adapters")
+                adapter_l = gr.Dropdown(
+                    choices=["None"] + clip_l_opts,
+                    label="CLIP-L (768d) Adapter",
+                    value="None"
+                )
+                adapter_g = gr.Dropdown(
+                    choices=["None"] + clip_g_opts,
+                    label="CLIP-G (1280d) Adapter",
+                    value="None"
+                )
+
+            # Adapter Controls
+            with gr.Group():
+                gr.Markdown("### Adapter Controls")
+                strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
+                noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
+                gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
+                use_anchor = gr.Checkbox(label="Use Anchor", value=True)
+
+            # Generation Settings
+            with gr.Group():
+                gr.Markdown("### Generation Settings")
+                with gr.Row():
+                    steps = gr.Slider(1, 100, value=25, step=1, label="Steps")
+                    cfg_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="CFG Scale")
+
+                scheduler_name = gr.Dropdown(
+                    choices=list(SCHEDULERS.keys()),
+                    value="DPM++ 2M",
+                    label="Scheduler"
+                )
+
+                with gr.Row():
+                    width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
+                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
+
+                seed = gr.Number(value=-1, label="Seed (-1 for random)")
+
+            run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
+
+        with gr.Column(scale=1):
+            # Output
+            with gr.Group():
+                gr.Markdown("### Generated Image")
+                out_img = gr.Image(label="Result", height=400)
+
+            # Visualizations
+            with gr.Group():
+                gr.Markdown("### Adapter Visualizations")
+                with gr.Row():
+                    delta_l = gr.Image(label="Δ CLIP-L", height=200)
+                    gate_l = gr.Image(label="Gate CLIP-L", height=200)
+                with gr.Row():
+                    delta_g = gr.Image(label="Δ CLIP-G", height=200)
+                    gate_g = gr.Image(label="Gate CLIP-G", height=200)
+
+            # Stats
+            with gr.Group():
+                gr.Markdown("### Adapter Statistics")
+                stats_l = gr.Textbox(label="CLIP-L Stats", interactive=False)
+                stats_g = gr.Textbox(label="CLIP-G Stats", interactive=False)
+
+    # Event handlers
+    def process_adapters(adapter_l_val, adapter_g_val):
+        # Convert "None" back to None for processing
+        adapter_l_processed = None if adapter_l_val == "None" else adapter_l_val
+        adapter_g_processed = None if adapter_g_val == "None" else adapter_g_val
+        return adapter_l_processed, adapter_g_processed
+
+    def run_inference(*args):
+        # Process adapter selections
+        adapter_l_processed, adapter_g_processed = process_adapters(args[2], args[3])
+
+        # Call inference with processed adapters
+        new_args = list(args)
+        new_args[2] = adapter_l_processed
+        new_args[3] = adapter_g_processed
+
+        return infer(*new_args)
+
+    run_btn.click(
+        fn=run_inference,
+        inputs=[
+            prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob,
+            use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
+        ],
+        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
+    )

-        self.guidance_proj = nn.Sequential(
-            nn.LayerNorm(self.bneck),
-            nn.Linear(self.bneck, 1),
-            nn.Sigmoid()
-        )
-
-    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
-        if self.config.get("assert_input_dims", True):
-            assert t5_seq.size(-1) == self.t5_dim
-            assert clip_seq.size(-1) == self.clip_dim
-
-        t5_b = self.proj_t5(t5_seq)
-        clip_b = self.proj_clip(clip_seq)
-
-        t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
-        c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)
-
-        pocket = self.pocket_blocks(t2c)
-
-        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)
-        h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))
-
-        anchor = self.anchor_proj(h)
-        delta = self.delta_proj(h) * self.gate_proj(h)
-        log_sigma = self.logsig_proj(h)
-
-        g_tok = self.guidance_proj(h).squeeze(-1)
-        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance
-
-        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, self.gate_proj(h)
+if __name__ == "__main__":
+    demo.launch(share=True)
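
Both call sites above, the old per-stream code and the new slice loop, consume the adapter through the same forward contract: TwoStreamShuntAdapter(config)(t5_seq, clip_seq) returns the 8-tuple (anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate). A minimal sketch of that call against the class as it stood before this commit; the config values below are illustrative placeholders, not the repo's shipped T5_SHUNT_REPOS entries:

import torch
from two_stream_shunt_adapter import TwoStreamShuntAdapter

# Illustrative config; the real values live in the repo's adapter config entries.
config = {
    "t5": {"hidden_size": 768},       # flan-t5-base hidden size
    "clip": {"hidden_size": 768},     # 1280 for the CLIP-G variant
    "bottleneck": 256,                # placeholder bottleneck width
    "heads": 8,
    "tau_init": 0.1,
    "max_guidance": 10.0,
}

adapter = TwoStreamShuntAdapter(config).eval()
t5_seq = torch.randn(1, 77, 768)      # T5 encoder hidden states
clip_seq = torch.randn(1, 77, 768)    # CLIP token embeddings to be shifted

with torch.no_grad():
    anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)

# anchor/delta/log_sigma: (1, 77, 768), gate: (1, 77, 1), g_pred: (1, 1), tau: (heads, 1, 1)

The gate and g_pred outputs are what the apps scale by gate_prob and report in the stats boxes; delta and anchor are the per-token updates blended back into the CLIP slice.
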