import spaces
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS
import io

# ─── Global Variables ─────────────────────────────────────────
t5_tok = None
t5_mod = None 
pipe = None
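# These are populated lazily inside infer() so the heavy models are only loaded on the first GPU request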

# Available schedulers
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}

# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]

# ─── Helper Functions ─────────────────────────────────────────
def load_adapter(repo, filename, config, device):
    """Load adapter from safetensors file"""
    from safetensors.torch import safe_open
    path = hf_hub_download(repo_id=repo, filename=filename)
    
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    return model.to(device)

def plot_heat(mat, title):
    """Create heatmap visualization with proper shape handling"""
    # Handle different input shapes
    if isinstance(mat, torch.Tensor):
        mat = mat.detach().cpu().numpy()
    
    # Ensure we have a 2D array for visualization
    if len(mat.shape) == 1:
        # 1D array - reshape to single row
        mat = mat.reshape(1, -1)
    elif len(mat.shape) == 3:
        # 3D array - average over batch dimension
        if mat.shape[0] == 1:
            mat = mat.squeeze(0)
        else:
            mat = mat.mean(axis=0)
    elif len(mat.shape) > 3:
        # Flatten higher dimensions
        mat = mat.reshape(-1, mat.shape[-1])
    
    # Create figure with proper DPI
    plt.figure(figsize=(8, 4), dpi=100)
    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper", interpolation='nearest')
    plt.title(title, fontsize=12, fontweight='bold')
    plt.xlabel("Token Position")
    plt.ylabel("Feature Dimension")
    plt.colorbar(shrink=0.8)
    plt.tight_layout()
    
    # Convert to PIL Image
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches='tight', dpi=100)
    buf.seek(0)
    pil_image = Image.open(buf)
    plt.close()
    
    # Convert to numpy array for Gradio
    return np.array(pil_image)

def encode_sdxl_prompt(pipe, prompt, negative_prompt, device):
    """Generate CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
    
    # Tokenize for both encoders
    tokens_l = pipe.tokenizer(
        prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    ).input_ids.to(device)
    
    tokens_g = pipe.tokenizer_2(
        prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    ).input_ids.to(device)
    
    neg_tokens_l = pipe.tokenizer(
        negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    ).input_ids.to(device)
    
    neg_tokens_g = pipe.tokenizer_2(
        negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    ).input_ids.to(device)
    
    with torch.no_grad():
        # CLIP-L: [0] = sequence, [1] = pooled
        clip_l_output = pipe.text_encoder(tokens_l, output_hidden_states=False)
        clip_l_embeds = clip_l_output[0]
        
        neg_clip_l_output = pipe.text_encoder(neg_tokens_l, output_hidden_states=False)
        neg_clip_l_embeds = neg_clip_l_output[0]
        
        # CLIP-G: [0] = pooled, [1] = sequence
        clip_g_output = pipe.text_encoder_2(tokens_g, output_hidden_states=False)
        clip_g_embeds = clip_g_output[1]  # sequence embeddings
        pooled_embeds = clip_g_output[0]  # pooled embeddings
        
        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g, output_hidden_states=False)
        neg_clip_g_embeds = neg_clip_g_output[1]
        neg_pooled_embeds = neg_clip_g_output[0]
    
    return {
        "clip_l": clip_l_embeds,
        "clip_g": clip_g_embeds,
        "neg_clip_l": neg_clip_l_embeds,
        "neg_clip_g": neg_clip_g_embeds,
        "pooled": pooled_embeds,
        "neg_pooled": neg_pooled_embeds
    }

# ─── Main Inference Function ──────────────────────────────────
@spaces.GPU
def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, delta_scale, 
          sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps, cfg_scale, 
          scheduler_name, width, height, seed):
    
    global t5_tok, t5_mod, pipe
    device = torch.device("cuda")
    dtype = torch.float16
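    # The SDXL pipeline runs in fp16; T5 and the shunt adapters run in fp32, so adapter
    # inputs are cast with .float() and their outputs cast back with .to(dtype) below.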
    
    # Initialize models
    if t5_tok is None:
        t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
        t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
        
    if pipe is None:
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype,
            variant="fp16",
            use_safetensors=True
        ).to(device)
    
    # Set seed
    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)
        generator = torch.Generator(device=device).manual_seed(seed)
    else:
        generator = None
    
    # Set scheduler
    if scheduler_name in SCHEDULERS:
        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
    
    # Get T5 embeddings
    t5_ids = t5_tok(
        prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True
    ).input_ids.to(device)
    
    with torch.no_grad():
        t5_seq = t5_mod(t5_ids).last_hidden_state
    
    # Get CLIP embeddings
    clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
    
    # Load and apply adapters
    # These two booru shunts expect a 4-head attention config; every other CLIP-L shunt uses 12 heads
    if adapter_l_file in (
        "t5-vit-l-14-dual_shunt_booru_13_000_000.safetensors",
        "t5-vit-l-14-dual_shunt_booru_51_200_000.safetensors",
    ):
        config_l["heads"] = 4
    else:
        config_l["heads"] = 12
    adapter_l = load_adapter(repo_l, adapter_l_file, config_l, device) if adapter_l_file else None
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g, device) if adapter_g_file else None
    
    # Apply CLIP-L adapter
    if adapter_l is not None:
        with torch.no_grad():
            # Run adapter forward pass
            adapter_output = adapter_l(t5_seq.float(), clip_embeds["clip_l"].float())
            
            # Unpack outputs (ensure correct number of outputs)
            if len(adapter_output) == 8:
                anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_output
            else:
                # Handle different return formats
                anchor_l = adapter_output[0]
                delta_l = adapter_output[1]
                log_sigma_l = adapter_output[2] if len(adapter_output) > 2 else torch.zeros_like(delta_l)
                gate_l = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_l)
                tau_l = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
                g_pred_l = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)
            
            # Scale delta values
            delta_l = delta_l * delta_scale
            
            # Apply g_pred scaling to gate
            gate_l = gate_l * g_pred_l * gpred_scale
            
            # Apply gate scaling
            gate_l_scaled = torch.sigmoid(gate_l) * gate_prob
            
            # Compute final delta with strength and gate
            delta_l_final = delta_l * strength * gate_l_scaled
            
            # Apply delta to embeddings
            clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype)
            
            # Apply sigma-based noise if specified
            if sigma_scale > 0:
                sigma_l = torch.exp(log_sigma_l * sigma_scale)
                clip_l_mod += torch.randn_like(clip_l_mod) * sigma_l.to(dtype)
            
            # Apply anchor mixing if enabled
            if use_anchor:
                clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype)
            
            # Add additional noise if specified
            if noise > 0:
                clip_l_mod += torch.randn_like(clip_l_mod) * noise
    else:
        clip_l_mod = clip_embeds["clip_l"]
        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
        g_pred_l = torch.tensor(0.0)
        tau_l = torch.tensor(0.0)
    
    # Apply CLIP-G adapter  
    if adapter_g is not None:
        with torch.no_grad():
            # Run adapter forward pass
            adapter_output = adapter_g(t5_seq.float(), clip_embeds["clip_g"].float())
            
            # Unpack outputs (ensure correct number of outputs)
            if len(adapter_output) == 8:
                anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_output
            else:
                # Handle different return formats
                anchor_g = adapter_output[0]
                delta_g = adapter_output[1]
                log_sigma_g = adapter_output[2] if len(adapter_output) > 2 else torch.zeros_like(delta_g)
                gate_g = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_g)
                tau_g = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
                g_pred_g = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)
            
            # Scale delta values
            delta_g = delta_g * delta_scale
            
            # Apply g_pred scaling to gate
            gate_g = gate_g * g_pred_g * gpred_scale
            
            # Apply gate scaling
            gate_g_scaled = torch.sigmoid(gate_g) * gate_prob
            
            # Compute final delta with strength and gate
            delta_g_final = delta_g * strength * gate_g_scaled
            
            # Apply delta to embeddings
            clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype)
            
            # Apply sigma-based noise if specified
            if sigma_scale > 0:
                sigma_g = torch.exp(log_sigma_g * sigma_scale)
                clip_g_mod += torch.randn_like(clip_g_mod) * sigma_g.to(dtype)
            
            # Apply anchor mixing if enabled
            if use_anchor:
                clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype)
            
            # Add additional noise if specified
            if noise > 0:
                clip_g_mod += torch.randn_like(clip_g_mod) * noise
    else:
        clip_g_mod = clip_embeds["clip_g"]
        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
        g_pred_g = torch.tensor(0.0)
        tau_g = torch.tensor(0.0)
    
    # Combine embeddings for SDXL: [CLIP-L(768) + CLIP-G(1280)] = 2048
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1)
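    # Negative embeddings are passed through unmodified; the adapters only steer the positive prompt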
    
    # Generate image
    image = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=clip_embeds["pooled"],
        negative_prompt_embeds=neg_embeds,
        negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        num_images_per_prompt=1,
        generator=generator
    ).images[0]
    
    # Create visualizations
    delta_l_viz = plot_heat(delta_l_final.squeeze(), "CLIP-L Delta Values")
    gate_l_viz = plot_heat(gate_l_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-L Gate Activations")
    delta_g_viz = plot_heat(delta_g_final.squeeze(), "CLIP-G Delta Values") 
    gate_g_viz = plot_heat(gate_g_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-G Gate Activations")
    
    # Statistics
    stats_l = f"g_pred_l: {float(g_pred_l.mean().item() if hasattr(g_pred_l, 'mean') else g_pred_l):.3f}, Ο„_l: {float(tau_l.mean().item() if hasattr(tau_l, 'mean') else tau_l):.3f}"
    stats_g = f"g_pred_g: {float(g_pred_g.mean().item() if hasattr(g_pred_g, 'mean') else g_pred_g):.3f}, Ο„_g: {float(tau_g.mean().item() if hasattr(tau_g, 'mean') else tau_g):.3f}"
    
    return image, delta_l_viz, gate_l_viz, delta_g_viz, gate_g_viz, stats_l, stats_g

# ─── Gradio Interface ─────────────────────────────────────────
def create_interface():
    with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual Shunt Adapter")
        gr.Markdown("*Enhance SDXL generation using T5 semantic understanding to modify CLIP embeddings*")
        
        with gr.Row():
            with gr.Column(scale=1):
                # Prompts
                gr.Markdown("### πŸ“ Prompts")
                prompt = gr.Textbox(
                    label="Prompt",
                    value="a futuristic control station with holographic displays",
                    lines=3,
                    placeholder="Describe what you want to generate..."
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, low quality, distorted",
                    lines=2,
                    placeholder="Describe what you want to avoid..."
                )
                
                # Adapters
                gr.Markdown("### βš™οΈ Adapters")
                adapter_l = gr.Dropdown(
                    choices=["None"] + clip_l_opts,
                    label="CLIP-L (768d) Adapter",
                    value="t5-vit-l-14-dual_shunt_caption.safetensors",
                    info="Choose adapter for CLIP-L embeddings"
                )
                adapter_g = gr.Dropdown(
                    choices=["None"] + clip_g_opts,
                    label="CLIP-G (1280d) Adapter",
                    value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
                    info="Choose adapter for CLIP-G embeddings"
                )
                
                # Controls
                gr.Markdown("### πŸŽ›οΈ Adapter Controls")
                strength = gr.Slider(0.0, 10.0, value=4.0, step=0.01, label="Adapter Strength")
                delta_scale = gr.Slider(-15.0, 15.0, value=0.2, step=0.1, label="Delta Scale", info="Scales the delta values, recommended 1")
                sigma_scale = gr.Slider(0, 15.0, value=0.1, step=0.1, label="Sigma Scale", info="Scales the noise variance, recommended 1")
                gpred_scale = gr.Slider(0.0, 20.0, value=2.0, step=0.01, label="G-Pred Scale", info="Scales the gate prediction, recommended 2")
                noise = gr.Slider(0.0, 1.0, value=0.55, step=0.01, label="Noise Injection")
                gate_prob = gr.Slider(0.0, 1.0, value=0.27, step=0.01, label="Gate Probability")
                use_anchor = gr.Checkbox(label="Use Anchor Points", value=True)
                
                # Generation Settings
                gr.Markdown("### 🎨 Generation Settings")
                with gr.Row():
                    steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
                    cfg_scale = gr.Slider(1.0, 15.0, value=7.5, step=0.1, label="CFG Scale")
                
                scheduler_name = gr.Dropdown(
                    choices=list(SCHEDULERS.keys()),
                    value="DPM++ 2M",
                    label="Scheduler"
                )
                
                with gr.Row():
                    width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                
                seed = gr.Number(value=-1, label="Seed (-1 for random)", precision=0)
                
                generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg")
            
            with gr.Column(scale=1):
                # Output
                gr.Markdown("### πŸ–ΌοΈ Generated Image")
                output_image = gr.Image(label="Result", height=400, show_label=False)
                
                # Visualizations
                gr.Markdown("### πŸ“Š Adapter Analysis")
                with gr.Row():
                    delta_l_img = gr.Image(label="CLIP-L Deltas", height=200)
                    gate_l_img = gr.Image(label="CLIP-L Gates", height=200)
                with gr.Row():
                    delta_g_img = gr.Image(label="CLIP-G Deltas", height=200)
                    gate_g_img = gr.Image(label="CLIP-G Gates", height=200)
                
                # Statistics
                gr.Markdown("### πŸ“ˆ Statistics")
                stats_l_text = gr.Textbox(label="CLIP-L Metrics", interactive=False)
                stats_g_text = gr.Textbox(label="CLIP-G Metrics", interactive=False)
        
        # Event handler
        def run_generation(*args):
            # Process adapter selections
            processed_args = list(args)
            processed_args[2] = None if args[2] == "None" else args[2]  # adapter_l
            processed_args[3] = None if args[3] == "None" else args[3]  # adapter_g
            return infer(*processed_args)
        
        generate_btn.click(
            fn=run_generation,
            inputs=[
                prompt, negative_prompt, adapter_l, adapter_g, strength, delta_scale,
                sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps, cfg_scale,
                scheduler_name, width, height, seed
            ],
            outputs=[output_image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l_text, stats_g_text]
        )
    
    return demo

# ─── Launch ────────────────────────────────────────────────────
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()