import io

import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from safetensors.torch import safe_open
from huggingface_hub import hf_hub_download

from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS

# ─── Device & Model Setup ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# T5 model for semantic understanding
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

# SDXL pipeline with its two text encoders
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
    use_safetensors=True
).to(device)

# Available schedulers
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}

# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]

# ─── Loader ───────────────────────────────────────────────────
def load_adapter(repo, filename, config):
    """Download a shunt adapter checkpoint from the Hub and load it onto the device."""
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    model.to(device)
    return model

# ─── Visualization ────────────────────────────────────────────
def plot_heat(mat, title):
    """Render a 2-D array as a heatmap and return it as a PIL image for gr.Image."""
    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
    ax.set_title(title)
    plt.colorbar(im, ax=ax)
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    plt.close(fig)
    return Image.open(buf)

# ─── SDXL Text Encoding ───────────────────────────────────────
def encode_sdxl_prompt(prompt, negative_prompt=""):
    """Generate CLIP-L and CLIP-G embeddings using SDXL's own text encoders."""
    # Tokenize for both encoders
    tokens_l = pipe.tokenizer(
        prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    tokens_g = pipe.tokenizer_2(
        prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    # Negative prompts
    neg_tokens_l = pipe.tokenizer(
        negative_prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    neg_tokens_g = pipe.tokenizer_2(
        negative_prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    with torch.no_grad():
        # CLIP-L sequence embeddings (768d); SDXL conditions on the penultimate hidden layer
        out_l = pipe.text_encoder(tokens_l, output_hidden_states=True)
        neg_out_l = pipe.text_encoder(neg_tokens_l, output_hidden_states=True)
        clip_l_embeds = out_l.hidden_states[-2]
        neg_clip_l_embeds = neg_out_l.hidden_states[-2]

        # CLIP-G sequence embeddings (1280d) plus the pooled embeddings SDXL requires.
        # text_encoder_2 is a CLIPTextModelWithProjection, so its first output is the
        # pooled/projected vector, not the token sequence.
        out_g = pipe.text_encoder_2(tokens_g, output_hidden_states=True)
        neg_out_g = pipe.text_encoder_2(neg_tokens_g, output_hidden_states=True)
        clip_g_embeds = out_g.hidden_states[-2]
        neg_clip_g_embeds = neg_out_g.hidden_states[-2]
        pooled_embeds = out_g.text_embeds
        neg_pooled_embeds = neg_out_g.text_embeds

    return {
        "clip_l": clip_l_embeds,
        "clip_g": clip_g_embeds,
        "neg_clip_l": neg_clip_l_embeds,
        "neg_clip_g": neg_clip_g_embeds,
        "pooled": pooled_embeds,
        "neg_pooled": neg_pooled_embeds
    }
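
# Quick shape reference for encode_sdxl_prompt (illustrative only; assumes the pipeline
# above has loaded and the example prompt is arbitrary):
#   embeds = encode_sdxl_prompt("a red bicycle")
#   embeds["clip_l"].shape  -> torch.Size([1, 77, 768])
#   embeds["clip_g"].shape  -> torch.Size([1, 77, 1280])
#   embeds["pooled"].shape  -> torch.Size([1, 1280])
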
# ─── Inference ────────────────────────────────────────────────
@torch.no_grad()
def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
          use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):

    # Gradio Number/Slider components may deliver floats, so cast the integer parameters
    seed = int(seed)
    steps, width, height = int(steps), int(width), int(height)

    # Set seed for reproducibility
    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Set scheduler
    if scheduler_name in SCHEDULERS:
        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)

    # Get T5 embeddings for semantic understanding
    t5_ids = t5_tok(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=77,  # match CLIP's standard sequence length
        truncation=True
    ).input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state

    # Get proper SDXL CLIP embeddings
    clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)

    # Load adapters
    adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None

    # Apply CLIP-L adapter
    if adapter_l is not None:
        anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
        gate_l_scaled = gate_l * gate_prob
        delta_l_final = delta_l * strength * gate_l_scaled
        clip_l_mod = clip_embeds["clip_l"] + delta_l_final
        if use_anchor:
            clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
        if noise > 0:
            clip_l_mod += torch.randn_like(clip_l_mod) * noise
    else:
        clip_l_mod = clip_embeds["clip_l"]
        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
        g_pred_l = torch.tensor(0.0)
        tau_l = torch.tensor(0.0)

    # Apply CLIP-G adapter
    if adapter_g is not None:
        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
        gate_g_scaled = gate_g * gate_prob
        delta_g_final = delta_g * strength * gate_g_scaled
        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
        if use_anchor:
            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
        if noise > 0:
            clip_g_mod += torch.randn_like(clip_g_mod) * noise
    else:
        clip_g_mod = clip_embeds["clip_g"]
        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
        g_pred_g = torch.tensor(0.0)
        tau_g = torch.tensor(0.0)

    # Combine embeddings in SDXL format: [CLIP-L(768) + CLIP-G(1280)] = 2048
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
    neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)

    # Generate image with proper SDXL parameters
    image = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=clip_embeds["pooled"],
        negative_prompt_embeds=neg_embeds,
        negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
    ).images[0]

    return (
        image,
        plot_heat(delta_l_final.squeeze().float().cpu().numpy(), "Δ CLIP-L"),
        plot_heat(gate_l_scaled.squeeze().float().cpu().numpy(), "Gate CLIP-L"),
        plot_heat(delta_g_final.squeeze().float().cpu().numpy(), "Δ CLIP-G"),
        plot_heat(gate_g_scaled.squeeze().float().cpu().numpy(), "Gate CLIP-G"),
        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
    )
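
# Recap of the adapter blend applied in infer() above (element-wise over the 77×d token grid):
#   gate_scaled = gate * gate_prob
#   delta_final = strength * gate_scaled * delta
#   clip_mod    = clip + delta_final
#   if use_anchor: clip_mod = (1 - gate_scaled) * clip_mod + gate_scaled * anchor
#   if noise > 0:  clip_mod += noise * N(0, I)
# The two modified streams are concatenated to (B, 77, 2048) before reaching the SDXL UNet.
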
{tau_l.mean().item():.3f}", f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}" ) # ─── Gradio Interface ───────────────────────────────────────── with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo: gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement") gr.Markdown("Enhance SDXL generation by using T5 semantic understanding to modify CLIP embeddings") with gr.Row(): with gr.Column(scale=1): # Prompts with gr.Group(): gr.Markdown("### Prompts") prompt = gr.Textbox( label="Prompt", value="a futuristic control station with holographic displays", lines=3 ) negative_prompt = gr.Textbox( label="Negative Prompt", value="blurry, low quality, distorted", lines=2 ) # Adapters with gr.Group(): gr.Markdown("### Adapters") adapter_l = gr.Dropdown( choices=["None"] + clip_l_opts, label="CLIP-L (768d) Adapter", value="None" ) adapter_g = gr.Dropdown( choices=["None"] + clip_g_opts, label="CLIP-G (1280d) Adapter", value="None" ) # Adapter Controls with gr.Group(): gr.Markdown("### Adapter Controls") strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength") noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection") gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability") use_anchor = gr.Checkbox(label="Use Anchor", value=True) # Generation Settings with gr.Group(): gr.Markdown("### Generation Settings") with gr.Row(): steps = gr.Slider(1, 100, value=25, step=1, label="Steps") cfg_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="CFG Scale") scheduler_name = gr.Dropdown( choices=list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler" ) with gr.Row(): width = gr.Slider(512, 1536, value=1024, step=64, label="Width") height = gr.Slider(512, 1536, value=1024, step=64, label="Height") seed = gr.Number(value=-1, label="Seed (-1 for random)") run_btn = gr.Button("🚀 Generate", variant="primary", size="lg") with gr.Column(scale=1): # Output with gr.Group(): gr.Markdown("### Generated Image") out_img = gr.Image(label="Result", height=400) # Visualizations with gr.Group(): gr.Markdown("### Adapter Visualizations") with gr.Row(): delta_l = gr.Image(label="Δ CLIP-L", height=200) gate_l = gr.Image(label="Gate CLIP-L", height=200) with gr.Row(): delta_g = gr.Image(label="Δ CLIP-G", height=200) gate_g = gr.Image(label="Gate CLIP-G", height=200) # Stats with gr.Group(): gr.Markdown("### Adapter Statistics") stats_l = gr.Textbox(label="CLIP-L Stats", interactive=False) stats_g = gr.Textbox(label="CLIP-G Stats", interactive=False) # Event handlers def process_adapters(adapter_l_val, adapter_g_val): # Convert "None" back to None for processing adapter_l_processed = None if adapter_l_val == "None" else adapter_l_val adapter_g_processed = None if adapter_g_val == "None" else adapter_g_val return adapter_l_processed, adapter_g_processed def run_inference(*args): # Process adapter selections adapter_l_processed, adapter_g_processed = process_adapters(args[2], args[3]) # Call inference with processed adapters new_args = list(args) new_args[2] = adapter_l_processed new_args[3] = adapter_g_processed return infer(*new_args) run_btn.click( fn=run_inference, inputs=[ prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor, steps, cfg_scale, scheduler_name, width, height, seed ], outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g] ) if __name__ == "__main__": demo.launch(share=True)