"""Gradio demo: SDXL generation with T5-conditioned two-stream shunt adapters
that modify the CLIP-L and CLIP-G text embeddings."""

import io

import spaces
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from safetensors.torch import safe_open
from huggingface_hub import hf_hub_download

from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS

# ─── Global Variables ─────────────────────────────────────────
t5_tok = None
t5_mod = None
pipe = None

# Available schedulers
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}

# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]


# ─── Helper Functions ─────────────────────────────────────────
def load_adapter(repo, filename, config, device):
    """Download a shunt-adapter checkpoint from the Hub and load it onto `device`."""
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    return model.to(device)


def plot_heat(mat, title):
    """Create a heatmap visualization with proper shape handling."""
    # Ensure we have a 2D array for visualization
    if len(mat.shape) == 1:
        mat = mat.reshape(1, -1)
    elif len(mat.shape) == 3:
        mat = mat.mean(axis=0)
    elif len(mat.shape) > 3:
        mat = mat.reshape(-1, mat.shape[-1])

    fig, ax = plt.subplots(figsize=(8, 4), dpi=100)
    im = ax.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.set_xlabel("Token Position")
    ax.set_ylabel("Feature Dimension")
    plt.colorbar(im, ax=ax, shrink=0.8)

    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
    buf.seek(0)
    pil_image = Image.open(buf)
    plt.close(fig)
    return pil_image


def encode_sdxl_prompt(pipe, prompt, negative_prompt, device):
    """Generate CLIP-L and CLIP-G embeddings using SDXL's two text encoders."""
    # Tokenize for both encoders
    tokens_l = pipe.tokenizer(
        prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    tokens_g = pipe.tokenizer_2(
        prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    neg_tokens_l = pipe.tokenizer(
        negative_prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    neg_tokens_g = pipe.tokenizer_2(
        negative_prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    with torch.no_grad():
        # CLIP-L (text_encoder): [0] = sequence embeddings, [1] = pooled output
        clip_l_embeds = pipe.text_encoder(tokens_l)[0]
        neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]

        # CLIP-G (text_encoder_2): [0] = pooled, [1] = sequence (ordering differs from CLIP-L!)
        clip_g_output = pipe.text_encoder_2(tokens_g)
        clip_g_embeds = clip_g_output[1]   # sequence embeddings
        pooled_embeds = clip_g_output[0]   # pooled embeddings

        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
        neg_clip_g_embeds = neg_clip_g_output[1]
        neg_pooled_embeds = neg_clip_g_output[0]

    return {
        "clip_l": clip_l_embeds,
        "clip_g": clip_g_embeds,
        "neg_clip_l": neg_clip_l_embeds,
        "neg_clip_g": neg_clip_g_embeds,
        "pooled": pooled_embeds,
        "neg_pooled": neg_pooled_embeds,
    }


# ─── Main Inference Function ──────────────────────────────────
@spaces.GPU
def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise,
          gate_prob, use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
    global t5_tok, t5_mod, pipe

    device = torch.device("cuda")
    dtype = torch.float16

    # Initialize models
    if t5_tok is None:
        t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
        t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

    if pipe is None:
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype, variant="fp16", use_safetensors=True
        ).to(device)

    # Gradio numeric inputs may arrive as floats; normalize to the integer
    # types that torch and the pipeline expect
    seed = int(seed)
    steps = int(steps)
    width = int(width)
    height = int(height)

    # Set seed
    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Set scheduler
    if scheduler_name in SCHEDULERS:
        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)

    # Get T5 embeddings
    t5_ids = t5_tok(
        prompt, return_tensors="pt", padding="max_length",
        max_length=77, truncation=True
    ).input_ids.to(device)
    with torch.no_grad():
        t5_seq = t5_mod(t5_ids).last_hidden_state

    # Get CLIP embeddings
    clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)

    # Load adapters (if selected)
    adapter_l = load_adapter(repo_l, adapter_l_file, config_l, device) if adapter_l_file else None
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g, device) if adapter_g_file else None

    # Apply CLIP-L adapter
    if adapter_l is not None:
        with torch.no_grad():
            anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(
                t5_seq.float(), clip_embeds["clip_l"].float()
            )
        gate_l_scaled = gate_l * gate_prob
        delta_l_final = delta_l * strength * gate_l_scaled
        clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype)
        if use_anchor:
            clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype)
        if noise > 0:
            clip_l_mod += torch.randn_like(clip_l_mod) * noise
    else:
        clip_l_mod = clip_embeds["clip_l"]
        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
        g_pred_l, tau_l = torch.tensor(0.0), torch.tensor(0.0)

    # Apply CLIP-G adapter
    if adapter_g is not None:
        with torch.no_grad():
            anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(
                t5_seq.float(), clip_embeds["clip_g"].float()
            )
        gate_g_scaled = gate_g * gate_prob
        delta_g_final = delta_g * strength * gate_g_scaled
        clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype)
        if use_anchor:
            clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype)
        if noise > 0:
            clip_g_mod += torch.randn_like(clip_g_mod) * noise
    else:
        clip_g_mod = clip_embeds["clip_g"]
        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
        g_pred_g, tau_g = torch.tensor(0.0), torch.tensor(0.0)

    # Combine embeddings for SDXL: [CLIP-L (768) + CLIP-G (1280)] = 2048
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1)

    # Generate image
    image = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=clip_embeds["pooled"],
        negative_prompt_embeds=neg_embeds,
        negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        num_images_per_prompt=1,
        generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
    ).images[0]

    # Create visualizations
    delta_l_viz = plot_heat(delta_l_final.squeeze().cpu().numpy(), "CLIP-L Delta Values")
    gate_l_viz = plot_heat(gate_l_scaled.squeeze().cpu().numpy().mean(axis=-1, keepdims=True), "CLIP-L Gate Activations")
    delta_g_viz = plot_heat(delta_g_final.squeeze().cpu().numpy(), "CLIP-G Delta Values")
    gate_g_viz = plot_heat(gate_g_scaled.squeeze().cpu().numpy().mean(axis=-1, keepdims=True), "CLIP-G Gate Activations")

    # Statistics
    stats_l = f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}"
    stats_g = f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"

    return image, delta_l_viz, gate_l_viz, delta_g_viz, gate_g_viz, stats_l, stats_g


# ─── Gradio Interface ─────────────────────────────────────────
def create_interface():
    with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual Shunt Adapter")
        gr.Markdown("*Enhance SDXL generation using T5 semantic understanding to modify CLIP embeddings*")

        with gr.Row():
            with gr.Column(scale=1):
                # Prompts
                gr.Markdown("### 📝 Prompts")
                prompt = gr.Textbox(
                    label="Prompt",
                    value="a futuristic control station with holographic displays",
                    lines=3,
                    placeholder="Describe what you want to generate..."
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, low quality, distorted",
                    lines=2,
                    placeholder="Describe what you want to avoid..."
                )

                # Adapters
                gr.Markdown("### ⚙️ Adapters")
                adapter_l = gr.Dropdown(
                    choices=["None"] + clip_l_opts,
                    label="CLIP-L (768d) Adapter",
                    value="None",
                    info="Choose adapter for CLIP-L embeddings"
                )
                adapter_g = gr.Dropdown(
                    choices=["None"] + clip_g_opts,
                    label="CLIP-G (1280d) Adapter",
                    value="None",
                    info="Choose adapter for CLIP-G embeddings"
                )

                # Controls
                gr.Markdown("### 🎛️ Adapter Controls")
                strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
                noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
                gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
                use_anchor = gr.Checkbox(label="Use Anchor Points", value=True)

                # Generation Settings
                gr.Markdown("### 🎨 Generation Settings")
                with gr.Row():
                    steps = gr.Slider(1, 50, value=25, step=1, label="Steps")
                    cfg_scale = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="CFG Scale")

                scheduler_name = gr.Dropdown(
                    choices=list(SCHEDULERS.keys()),
                    value="DPM++ 2M",
                    label="Scheduler"
                )

                with gr.Row():
                    width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")

                seed = gr.Number(value=-1, label="Seed (-1 for random)")

                generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg")

            with gr.Column(scale=1):
                # Output
                gr.Markdown("### 🖼️ Generated Image")
                output_image = gr.Image(label="Result", height=400, show_label=False)

                # Visualizations
                gr.Markdown("### 📊 Adapter Analysis")
                with gr.Row():
                    delta_l_img = gr.Image(label="CLIP-L Deltas", height=200)
                    gate_l_img = gr.Image(label="CLIP-L Gates", height=200)
                with gr.Row():
                    delta_g_img = gr.Image(label="CLIP-G Deltas", height=200)
                    gate_g_img = gr.Image(label="CLIP-G Gates", height=200)

                # Statistics
                gr.Markdown("### 📈 Statistics")
                stats_l_text = gr.Textbox(label="CLIP-L Metrics", interactive=False)
                stats_g_text = gr.Textbox(label="CLIP-G Metrics", interactive=False)

        # Event handler
        def run_generation(*args):
            # Process adapter selections ("None" means no adapter file)
            processed_args = list(args)
            processed_args[2] = None if args[2] == "None" else args[2]  # adapter_l
            processed_args[3] = None if args[3] == "None" else args[3]  # adapter_g
            return infer(*processed_args)

        generate_btn.click(
            fn=run_generation,
            inputs=[
                prompt, negative_prompt, adapter_l, adapter_g, strength, noise,
                gate_prob, use_anchor, steps, cfg_scale, scheduler_name,
                width, height, seed
            ],
            outputs=[output_image, delta_l_img, gate_l_img, delta_g_img,
                     gate_g_img, stats_l_text, stats_g_text]
        )

    return demo


# ─── Launch ────────────────────────────────────────────────────
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()