import spaces import torch import gradio as gr import numpy as np import matplotlib.pyplot as plt from PIL import Image from transformers import T5Tokenizer, T5EncoderModel from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler from safetensors.torch import load_file from huggingface_hub import hf_hub_download from two_stream_shunt_adapter import TwoStreamShuntAdapter from configs import T5_SHUNT_REPOS import io # ─── Global Variables ───────────────────────────────────────── t5_tok = None t5_mod = None pipe = None # Available schedulers SCHEDULERS = { "DPM++ 2M": DPMSolverMultistepScheduler, "DDIM": DDIMScheduler, "Euler": EulerDiscreteScheduler, } # ─── Adapter Configs ────────────────────────────────────────── clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"] clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"] repo_l = T5_SHUNT_REPOS["clip_l"]["repo"] repo_g = T5_SHUNT_REPOS["clip_g"]["repo"] config_l = T5_SHUNT_REPOS["clip_l"]["config"] config_g = T5_SHUNT_REPOS["clip_g"]["config"] # ─── Helper Functions ───────────────────────────────────────── def load_adapter(repo, filename, config, device): """Load adapter from safetensors file""" from safetensors.torch import safe_open path = hf_hub_download(repo_id=repo, filename=filename) model = TwoStreamShuntAdapter(config).eval() tensors = {} with safe_open(path, framework="pt", device="cpu") as f: for key in f.keys(): tensors[key] = f.get_tensor(key) model.load_state_dict(tensors) return model.to(device) def plot_heat(mat, title): """Create heatmap visualization with proper shape handling""" # Handle different input shapes if isinstance(mat, torch.Tensor): mat = mat.detach().cpu().numpy() # Ensure we have a 2D array for visualization if len(mat.shape) == 1: # 1D array - reshape to single row mat = mat.reshape(1, -1) elif len(mat.shape) == 3: # 3D array - average over batch dimension if mat.shape[0] == 1: mat = mat.squeeze(0) else: mat = mat.mean(axis=0) elif len(mat.shape) > 3: # Flatten higher dimensions mat = mat.reshape(-1, mat.shape[-1]) # Create figure with proper DPI plt.figure(figsize=(8, 4), dpi=100) plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper", interpolation='nearest') plt.title(title, fontsize=12, fontweight='bold') plt.xlabel("Token Position") plt.ylabel("Feature Dimension") plt.colorbar(shrink=0.8) plt.tight_layout() # Convert to PIL Image buf = io.BytesIO() plt.savefig(buf, format="png", bbox_inches='tight', dpi=100) buf.seek(0) pil_image = Image.open(buf) plt.close() # Convert to numpy array for Gradio return np.array(pil_image) def encode_sdxl_prompt(pipe, prompt, negative_prompt, device): """Generate CLIP-L and CLIP-G embeddings using SDXL's text encoders""" # Tokenize for both encoders tokens_l = pipe.tokenizer( prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt" ).input_ids.to(device) tokens_g = pipe.tokenizer_2( prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt" ).input_ids.to(device) neg_tokens_l = pipe.tokenizer( negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt" ).input_ids.to(device) neg_tokens_g = pipe.tokenizer_2( negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt" ).input_ids.to(device) with torch.no_grad(): # CLIP-L: [0] = sequence, [1] = pooled clip_l_output = pipe.text_encoder(tokens_l, output_hidden_states=False) clip_l_embeds = clip_l_output[0] neg_clip_l_output = pipe.text_encoder(neg_tokens_l, output_hidden_states=False) neg_clip_l_embeds = neg_clip_l_output[0] # CLIP-G: [0] = pooled, [1] = sequence clip_g_output = pipe.text_encoder_2(tokens_g, output_hidden_states=False) clip_g_embeds = clip_g_output[1] # sequence embeddings pooled_embeds = clip_g_output[0] # pooled embeddings neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g, output_hidden_states=False) neg_clip_g_embeds = neg_clip_g_output[1] neg_pooled_embeds = neg_clip_g_output[0] return { "clip_l": clip_l_embeds, "clip_g": clip_g_embeds, "neg_clip_l": neg_clip_l_embeds, "neg_clip_g": neg_clip_g_embeds, "pooled": pooled_embeds, "neg_pooled": neg_pooled_embeds } # ─── Main Inference Function ────────────────────────────────── @spaces.GPU def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, delta_scale, sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps, cfg_scale, scheduler_name, width, height, seed): global t5_tok, t5_mod, pipe device = torch.device("cuda") dtype = torch.float16 # Initialize models if t5_tok is None: t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base") t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval() if pipe is None: pipe = StableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype, variant="fp16", use_safetensors=True ).to(device) # Set seed if seed != -1: torch.manual_seed(seed) np.random.seed(seed) generator = torch.Generator(device=device).manual_seed(seed) else: generator = None # Set scheduler if scheduler_name in SCHEDULERS: pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config) # Get T5 embeddings t5_ids = t5_tok( prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True ).input_ids.to(device) with torch.no_grad(): t5_seq = t5_mod(t5_ids).last_hidden_state # Get CLIP embeddings clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device) # Load and apply adapters if(adapter_l_file == "t5-vit-l-14-dual_shunt_booru_13_000_000.safetensors" or adapter_l_file == "t5-vit-l-14-dual_shunt_booru_51_200_000.safetensors"): config_l["heads"] = 4 else: config_l["heads"] = 12 adapter_l = load_adapter(repo_l, adapter_l_file, config_l, device) if adapter_l_file else None adapter_g = load_adapter(repo_g, adapter_g_file, config_g, device) if adapter_g_file else None # Apply CLIP-L adapter if adapter_l is not None: with torch.no_grad(): # Run adapter forward pass adapter_output = adapter_l(t5_seq.float(), clip_embeds["clip_l"].float()) # Unpack outputs (ensure correct number of outputs) if len(adapter_output) == 8: anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_output else: # Handle different return formats anchor_l = adapter_output[0] delta_l = adapter_output[1] log_sigma_l = adapter_output[2] if len(adapter_output) > 2 else torch.zeros_like(delta_l) gate_l = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_l) tau_l = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0) g_pred_l = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0) # Scale delta values delta_l = delta_l * delta_scale # Apply g_pred scaling to gate gate_l = gate_l * g_pred_l * gpred_scale # Apply gate scaling gate_l_scaled = torch.sigmoid(gate_l) * gate_prob # Compute final delta with strength and gate delta_l_final = delta_l * strength * gate_l_scaled # Apply delta to embeddings clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype) # Apply sigma-based noise if specified if sigma_scale > 0: sigma_l = torch.exp(log_sigma_l * sigma_scale) clip_l_mod += torch.randn_like(clip_l_mod) * sigma_l.to(dtype) # Apply anchor mixing if enabled if use_anchor: clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype) # Add additional noise if specified if noise > 0: clip_l_mod += torch.randn_like(clip_l_mod) * noise else: clip_l_mod = clip_embeds["clip_l"] delta_l_final = torch.zeros_like(clip_embeds["clip_l"]) gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"]) g_pred_l = torch.tensor(0.0) tau_l = torch.tensor(0.0) # Apply CLIP-G adapter if adapter_g is not None: with torch.no_grad(): # Run adapter forward pass adapter_output = adapter_g(t5_seq.float(), clip_embeds["clip_g"].float()) # Unpack outputs (ensure correct number of outputs) if len(adapter_output) == 8: anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_output else: # Handle different return formats anchor_g = adapter_output[0] delta_g = adapter_output[1] log_sigma_g = adapter_output[2] if len(adapter_output) > 2 else torch.zeros_like(delta_g) gate_g = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_g) tau_g = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0) g_pred_g = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0) # Scale delta values delta_g = delta_g * delta_scale # Apply g_pred scaling to gate gate_g = gate_g * g_pred_g * gpred_scale # Apply gate scaling gate_g_scaled = torch.sigmoid(gate_g) * gate_prob # Compute final delta with strength and gate delta_g_final = delta_g * strength * gate_g_scaled # Apply delta to embeddings clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype) # Apply sigma-based noise if specified if sigma_scale > 0: sigma_g = torch.exp(log_sigma_g * sigma_scale) clip_g_mod += torch.randn_like(clip_g_mod) * sigma_g.to(dtype) # Apply anchor mixing if enabled if use_anchor: clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype) # Add additional noise if specified if noise > 0: clip_g_mod += torch.randn_like(clip_g_mod) * noise else: clip_g_mod = clip_embeds["clip_g"] delta_g_final = torch.zeros_like(clip_embeds["clip_g"]) gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"]) g_pred_g = torch.tensor(0.0) tau_g = torch.tensor(0.0) # Combine embeddings for SDXL: [CLIP-L(768) + CLIP-G(1280)] = 2048 prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1) neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1) # Generate image image = pipe( prompt_embeds=prompt_embeds, pooled_prompt_embeds=clip_embeds["pooled"], negative_prompt_embeds=neg_embeds, negative_pooled_prompt_embeds=clip_embeds["neg_pooled"], num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, num_images_per_prompt=1, generator=generator ).images[0] # Create visualizations delta_l_viz = plot_heat(delta_l_final.squeeze(), "CLIP-L Delta Values") gate_l_viz = plot_heat(gate_l_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-L Gate Activations") delta_g_viz = plot_heat(delta_g_final.squeeze(), "CLIP-G Delta Values") gate_g_viz = plot_heat(gate_g_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-G Gate Activations") # Statistics stats_l = f"g_pred_l: {float(g_pred_l.mean().item() if hasattr(g_pred_l, 'mean') else g_pred_l):.3f}, τ_l: {float(tau_l.mean().item() if hasattr(tau_l, 'mean') else tau_l):.3f}" stats_g = f"g_pred_g: {float(g_pred_g.mean().item() if hasattr(g_pred_g, 'mean') else g_pred_g):.3f}, τ_g: {float(tau_g.mean().item() if hasattr(tau_g, 'mean') else tau_g):.3f}" return image, delta_l_viz, gate_l_viz, delta_g_viz, gate_g_viz, stats_l, stats_g # ─── Gradio Interface ───────────────────────────────────────── def create_interface(): with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo: gr.Markdown("# 🧠 SDXL Dual Shunt Adapter") gr.Markdown("*Enhance SDXL generation using T5 semantic understanding to modify CLIP embeddings*") with gr.Row(): with gr.Column(scale=1): # Prompts gr.Markdown("### 📝 Prompts") prompt = gr.Textbox( label="Prompt", value="a futuristic control station with holographic displays", lines=3, placeholder="Describe what you want to generate..." ) negative_prompt = gr.Textbox( label="Negative Prompt", value="blurry, low quality, distorted", lines=2, placeholder="Describe what you want to avoid..." ) # Adapters gr.Markdown("### ⚙️ Adapters") adapter_l = gr.Dropdown( choices=["None"] + clip_l_opts, label="CLIP-L (768d) Adapter", value="t5-vit-l-14-dual_shunt_caption.safetensors", info="Choose adapter for CLIP-L embeddings" ) adapter_g = gr.Dropdown( choices=["None"] + clip_g_opts, label="CLIP-G (1280d) Adapter", value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors", info="Choose adapter for CLIP-G embeddings" ) # Controls gr.Markdown("### 🎛️ Adapter Controls") strength = gr.Slider(0.0, 10.0, value=4.0, step=0.01, label="Adapter Strength") delta_scale = gr.Slider(-15.0, 15.0, value=0.2, step=0.1, label="Delta Scale", info="Scales the delta values, recommended 1") sigma_scale = gr.Slider(0, 15.0, value=0.1, step=0.1, label="Sigma Scale", info="Scales the noise variance, recommended 1") gpred_scale = gr.Slider(0.0, 20.0, value=2.0, step=0.01, label="G-Pred Scale", info="Scales the gate prediction, recommended 2") noise = gr.Slider(0.0, 1.0, value=0.55, step=0.01, label="Noise Injection") gate_prob = gr.Slider(0.0, 1.0, value=0.27, step=0.01, label="Gate Probability") use_anchor = gr.Checkbox(label="Use Anchor Points", value=True) # Generation Settings gr.Markdown("### 🎨 Generation Settings") with gr.Row(): steps = gr.Slider(1, 50, value=20, step=1, label="Steps") cfg_scale = gr.Slider(1.0, 15.0, value=7.5, step=0.1, label="CFG Scale") scheduler_name = gr.Dropdown( choices=list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler" ) with gr.Row(): width = gr.Slider(512, 1536, value=1024, step=64, label="Width") height = gr.Slider(512, 1536, value=1024, step=64, label="Height") seed = gr.Number(value=-1, label="Seed (-1 for random)", precision=0) generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg") with gr.Column(scale=1): # Output gr.Markdown("### 🖼️ Generated Image") output_image = gr.Image(label="Result", height=400, show_label=False) # Visualizations gr.Markdown("### 📊 Adapter Analysis") with gr.Row(): delta_l_img = gr.Image(label="CLIP-L Deltas", height=200) gate_l_img = gr.Image(label="CLIP-L Gates", height=200) with gr.Row(): delta_g_img = gr.Image(label="CLIP-G Deltas", height=200) gate_g_img = gr.Image(label="CLIP-G Gates", height=200) # Statistics gr.Markdown("### 📈 Statistics") stats_l_text = gr.Textbox(label="CLIP-L Metrics", interactive=False) stats_g_text = gr.Textbox(label="CLIP-G Metrics", interactive=False) # Event handler def run_generation(*args): # Process adapter selections processed_args = list(args) processed_args[2] = None if args[2] == "None" else args[2] # adapter_l processed_args[3] = None if args[3] == "None" else args[3] # adapter_g return infer(*processed_args) generate_btn.click( fn=run_generation, inputs=[ prompt, negative_prompt, adapter_l, adapter_g, strength, delta_scale, sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps, cfg_scale, scheduler_name, width, height, seed ], outputs=[output_image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l_text, stats_g_text] ) return demo # ─── Launch ──────────────────────────────────────────────────── if __name__ == "__main__": demo = create_interface() demo.launch()