import spaces
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS

# ─── Device & Model Setup ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# T5 Model for semantic understanding
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
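# flan-t5-base emits 768-d token states; the shunt adapters below consume this sequence
# together with the CLIP-L/CLIP-G sequences and predict per-token deltas in CLIP space.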

# SDXL Pipeline with proper text encoders
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
    use_safetensors=True
).to(device)

# Available schedulers
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}
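# The chosen class is swapped in at inference time via
# SchedulerClass.from_config(pipe.scheduler.config), reusing the pipeline's timestep settings.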

# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]
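# Assumed layout of T5_SHUNT_REPOS (a sketch; configs.py is the authoritative source):
#   {
#       "clip_l": {"repo": "<hf repo id>",
#                  "config": {...adapter hyperparameters...},
#                  "shunts_available": {"shunt_list": ["<checkpoint>.safetensors", ...]}},
#       "clip_g": {...same structure for the 1280-d branch...},
#   }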

# ─── Loader ───────────────────────────────────────────────────
from safetensors.torch import safe_open

def load_adapter(repo, filename, config):
    path = hf_hub_download(repo_id=repo, filename=filename)
    
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    model.to(device)
    return model
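
# Example call (hypothetical filename; the real names come from the adapter dropdowns below):
#   adapter_l = load_adapter(repo_l, "t5_clip_l_shunt.safetensors", config_l)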

# ─── Visualization ────────────────────────────────────────────
def plot_heat(mat, title):
    """Render a matrix as a heatmap and return it as a PIL image for Gradio."""
    import io
    mat = np.asarray(mat, dtype=np.float32)  # upcast fp16 adapter outputs for matplotlib
    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
    ax.set_title(title)
    plt.colorbar(im, ax=ax)
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    plt.close(fig)
    # gr.Image expects a PIL image, numpy array, or file path rather than a raw buffer
    return Image.open(buf)

# ─── SDXL Text Encoding ───────────────────────────────────────
def encode_sdxl_prompt(prompt, negative_prompt=""):
    """Generate proper CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
    
    # Tokenize for both encoders
    tokens_l = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)
    
    tokens_g = pipe.tokenizer_2(
        prompt,
        padding="max_length", 
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)
    
    # Negative prompts
    neg_tokens_l = pipe.tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)
    
    neg_tokens_g = pipe.tokenizer_2(
        negative_prompt,
        padding="max_length",
        max_length=77, 
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)
    
    with torch.no_grad():
        # CLIP-L (text_encoder): output[0] is the 77-token sequence of 768-d hidden states
        clip_l_embeds = pipe.text_encoder(tokens_l)[0]
        neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]

        # CLIP-G (text_encoder_2, with projection head): output[0] is the pooled embedding
        # and output[1] is the 77-token sequence of 1280-d states (the reverse of CLIP-L)
        clip_g_output = pipe.text_encoder_2(tokens_g)
        clip_g_embeds = clip_g_output[1]  # sequence embeddings

        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
        neg_clip_g_embeds = neg_clip_g_output[1]  # sequence embeddings

        # Pooled embeddings required by the SDXL pipeline
        pooled_embeds = clip_g_output[0]
        neg_pooled_embeds = neg_clip_g_output[0]
    
    return {
        "clip_l": clip_l_embeds,
        "clip_g": clip_g_embeds,
        "neg_clip_l": neg_clip_l_embeds,
        "neg_clip_g": neg_clip_g_embeds,
        "pooled": pooled_embeds,
        "neg_pooled": neg_pooled_embeds
    }
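
# For a single prompt the dict above holds clip_l (1, 77, 768), clip_g (1, 77, 1280)
# and pooled (1, 1280); the negative entries mirror those shapes.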

# ─── Inference ────────────────────────────────────────────────
@spaces.GPU
@torch.no_grad()
def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, 
          use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
    
    # Cast numeric inputs defensively (Gradio may deliver them as floats), then seed
    seed = int(seed)
    steps, width, height = int(steps), int(width), int(height)
    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)
    
    # Set scheduler
    if scheduler_name in SCHEDULERS:
        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
    
    # Get T5 embeddings for semantic understanding - standardize to 77 tokens like CLIP
    t5_ids = t5_tok(
        prompt, 
        return_tensors="pt", 
        padding="max_length", 
        max_length=77,
        truncation=True
    ).input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state
    
    # Get proper SDXL CLIP embeddings
    clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)
    
    # Debug shapes
    print(f"T5 seq shape: {t5_seq.shape}")
    print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
    print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
    
    # Load adapters
    adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None
    
    # Apply CLIP-L adapter
    if adapter_l is not None:
        anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
        gate_l_scaled = gate_l * gate_prob
        delta_l_final = delta_l * strength * gate_l_scaled
        clip_l_mod = clip_embeds["clip_l"] + delta_l_final
        if use_anchor:
            clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
        if noise > 0:
            clip_l_mod += torch.randn_like(clip_l_mod) * noise
    else:
        clip_l_mod = clip_embeds["clip_l"]
        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
        g_pred_l = torch.tensor(0.0)
        tau_l = torch.tensor(0.0)
    
    # Apply CLIP-G adapter  
    if adapter_g is not None:
        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
        gate_g_scaled = gate_g * gate_prob
        delta_g_final = delta_g * strength * gate_g_scaled
        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
        if use_anchor:
            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
        if noise > 0:
            clip_g_mod += torch.randn_like(clip_g_mod) * noise
    else:
        clip_g_mod = clip_embeds["clip_g"]
        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
        g_pred_g = torch.tensor(0.0)
        tau_g = torch.tensor(0.0)
    
    # Combine embeddings in SDXL format: [CLIP-L(768) + CLIP-G(1280)] = 2048
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
    neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)
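    # 768-d CLIP-L + 1280-d CLIP-G concatenate to (1, 77, 2048), the cross-attention
    # width SDXL's UNet expects for prompt_embeds.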
    
    # Generate image with proper SDXL parameters
    image = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=clip_embeds["pooled"],
        negative_prompt_embeds=neg_embeds,
        negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        num_images_per_prompt=1,  # one image per call
        generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
    ).images[0]
    
    return (
        image,
        plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
        plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
        plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
        plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
    )

# ─── Gradio Interface ─────────────────────────────────────────
with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 SDXL Dual Shunt Adapter β€’ T5β†’CLIP Enhancement")
    gr.Markdown("Enhance SDXL generation by using T5 semantic understanding to modify CLIP embeddings")
    
    with gr.Row():
        with gr.Column(scale=1):
            # Prompts
            with gr.Group():
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(
                    label="Prompt", 
                    value="a futuristic control station with holographic displays",
                    lines=3
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, low quality, distorted",
                    lines=2
                )
            
            # Adapters
            with gr.Group():
                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(
                    choices=["None"] + clip_l_opts, 
                    label="CLIP-L (768d) Adapter",
                    value="None"
                )
                adapter_g = gr.Dropdown(
                    choices=["None"] + clip_g_opts, 
                    label="CLIP-G (1280d) Adapter", 
                    value="None"
                )
            
            # Adapter Controls
            with gr.Group():
                gr.Markdown("### Adapter Controls")
                strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
                noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
                gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
                use_anchor = gr.Checkbox(label="Use Anchor", value=True)
            
            # Generation Settings
            with gr.Group():
                gr.Markdown("### Generation Settings")
                with gr.Row():
                    steps = gr.Slider(1, 100, value=25, step=1, label="Steps")
                    cfg_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="CFG Scale")
                
                scheduler_name = gr.Dropdown(
                    choices=list(SCHEDULERS.keys()),
                    value="DPM++ 2M",
                    label="Scheduler"
                )
                
                with gr.Row():
                    width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                
                seed = gr.Number(value=-1, label="Seed (-1 for random)")
            
            run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
        
        with gr.Column(scale=1):
            # Output
            with gr.Group():
                gr.Markdown("### Generated Image")
                out_img = gr.Image(label="Result", height=400)
            
            # Visualizations
            with gr.Group():
                gr.Markdown("### Adapter Visualizations")
                with gr.Row():
                    delta_l = gr.Image(label="Δ CLIP-L", height=200)
                    gate_l = gr.Image(label="Gate CLIP-L", height=200)
                with gr.Row():
                    delta_g = gr.Image(label="Δ CLIP-G", height=200)
                    gate_g = gr.Image(label="Gate CLIP-G", height=200)
            
            # Stats
            with gr.Group():
                gr.Markdown("### Adapter Statistics")
                stats_l = gr.Textbox(label="CLIP-L Stats", interactive=False)
                stats_g = gr.Textbox(label="CLIP-G Stats", interactive=False)
    
    # Event handlers
    def process_adapters(adapter_l_val, adapter_g_val):
        # Convert "None" back to None for processing
        adapter_l_processed = None if adapter_l_val == "None" else adapter_l_val
        adapter_g_processed = None if adapter_g_val == "None" else adapter_g_val
        return adapter_l_processed, adapter_g_processed
    
    def run_inference(*args):
        # Process adapter selections
        adapter_l_processed, adapter_g_processed = process_adapters(args[2], args[3])
        
        # Call inference with processed adapters
        new_args = list(args)
        new_args[2] = adapter_l_processed
        new_args[3] = adapter_g_processed
        
        return infer(*new_args)
    
    run_btn.click(
        fn=run_inference,
        inputs=[
            prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob, 
            use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
        ],
        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
    )

if __name__ == "__main__":
    demo.launch()