"""SDXL Dual Shunt Adapter demo: T5 semantic features steer SDXL's CLIP embeddings."""

import io

import spaces
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import (
    StableDiffusionXLPipeline,
    DDIMScheduler,
    EulerDiscreteScheduler,
    DPMSolverMultistepScheduler,
)
from safetensors.torch import safe_open
from huggingface_hub import hf_hub_download

from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS

# ─── Device & Model Setup ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# T5 encoder for semantic understanding
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

# SDXL pipeline with both CLIP text encoders
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
    use_safetensors=True
).to(device)

# Available schedulers
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}

# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]

# ─── Loader ───────────────────────────────────────────────────
def load_adapter(repo, filename, config):
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    model.to(device)
    return model

# ─── Visualization ────────────────────────────────────────────
def plot_heat(mat, title):
    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
    ax.set_title(title)
    plt.colorbar(im, ax=ax)
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    plt.close(fig)
    # Return a PIL image: gr.Image accepts PIL/numpy/paths, not raw BytesIO
    img = Image.open(buf)
    img.load()
    return img

# ─── SDXL Text Encoding ───────────────────────────────────────
def encode_sdxl_prompt(prompt, negative_prompt=""):
    """Generate CLIP-L and CLIP-G embeddings with SDXL's own text encoders."""
    # Tokenize for both encoders
    tokens_l = pipe.tokenizer(
        prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    tokens_g = pipe.tokenizer_2(
        prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    # Negative prompts
    neg_tokens_l = pipe.tokenizer(
        negative_prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    neg_tokens_g = pipe.tokenizer_2(
        negative_prompt, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    with torch.no_grad():
        # CLIP-L (768-d): output [0] is the token sequence
        clip_l_embeds = pipe.text_encoder(tokens_l)[0]
        neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]

        # CLIP-G (1280-d): output [0] is the pooled embedding and [1] the
        # token sequence (the opposite ordering from CLIP-L)
        clip_g_output = pipe.text_encoder_2(tokens_g)
        clip_g_embeds = clip_g_output[1]  # sequence embeddings
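        # Shape reference at batch size 1: clip_l is [1, 77, 768],
        # clip_g is [1, 77, 1280], pooled is [1, 1280]; infer() below
        # concatenates the two sequences into SDXL's 2048-d conditioning.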
        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
        neg_clip_g_embeds = neg_clip_g_output[1]  # sequence embeddings

        # Pooled CLIP-G embeddings, required by SDXL's added conditioning
        pooled_embeds = clip_g_output[0]
        neg_pooled_embeds = neg_clip_g_output[0]

    return {
        "clip_l": clip_l_embeds,
        "clip_g": clip_g_embeds,
        "neg_clip_l": neg_clip_l_embeds,
        "neg_clip_g": neg_clip_g_embeds,
        "pooled": pooled_embeds,
        "neg_pooled": neg_pooled_embeds
    }

# ─── Inference ────────────────────────────────────────────────
@spaces.GPU
@torch.no_grad()
def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise,
          gate_prob, use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
    # gr.Number delivers a float; torch.manual_seed requires an int
    seed = int(seed)

    # Set seed for reproducibility (-1 requests a random seed)
    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Set scheduler
    if scheduler_name in SCHEDULERS:
        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)

    # T5 embeddings for semantic guidance, padded/truncated to 77 tokens to match CLIP
    t5_ids = t5_tok(
        prompt, return_tensors="pt", padding="max_length",
        max_length=77, truncation=True
    ).input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state

    # SDXL CLIP embeddings
    clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)

    # Debug shapes
    print(f"T5 seq shape: {t5_seq.shape}")
    print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
    print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")

    # Load adapters
    adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None

    # Apply CLIP-L adapter: gate the predicted delta, then optionally blend
    # toward the adapter's anchor representation
    if adapter_l is not None:
        anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = \
            adapter_l(t5_seq, clip_embeds["clip_l"])
        gate_l_scaled = gate_l * gate_prob
        delta_l_final = delta_l * strength * gate_l_scaled
        clip_l_mod = clip_embeds["clip_l"] + delta_l_final
        if use_anchor:
            clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
        if noise > 0:
            clip_l_mod += torch.randn_like(clip_l_mod) * noise
    else:
        clip_l_mod = clip_embeds["clip_l"]
        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
        g_pred_l = torch.tensor(0.0)
        tau_l = torch.tensor(0.0)

    # Apply CLIP-G adapter (same scheme in 1280-d)
    if adapter_g is not None:
        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = \
            adapter_g(t5_seq, clip_embeds["clip_g"])
        gate_g_scaled = gate_g * gate_prob
        delta_g_final = delta_g * strength * gate_g_scaled
        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
        if use_anchor:
            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
        if noise > 0:
            clip_g_mod += torch.randn_like(clip_g_mod) * noise
    else:
        clip_g_mod = clip_embeds["clip_g"]
        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
        g_pred_g = torch.tensor(0.0)
        tau_g = torch.tensor(0.0)

    # Combine embeddings in SDXL format: [CLIP-L (768) | CLIP-G (1280)] = 2048-d
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
    neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)

    # Generate the image through SDXL's embedding-level API
    image = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=clip_embeds["pooled"],
        negative_prompt_embeds=neg_embeds,
        negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        num_images_per_prompt=1,
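        # A fixed seed gets its own torch.Generator; -1 leaves sampling random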
        generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
    ).images[0]

    return (
        image,
        plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
        plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
        plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
        plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
    )

# ─── Gradio Interface ─────────────────────────────────────────
with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement")
    gr.Markdown("Enhance SDXL generation by using T5 semantic understanding to modify the CLIP embeddings.")

    with gr.Row():
        with gr.Column(scale=1):
            # Prompts
            with gr.Group():
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(
                    label="Prompt",
                    value="a futuristic control station with holographic displays",
                    lines=3
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, low quality, distorted",
                    lines=2
                )

            # Adapters
            with gr.Group():
                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(
                    choices=["None"] + clip_l_opts,
                    label="CLIP-L (768d) Adapter",
                    value="None"
                )
                adapter_g = gr.Dropdown(
                    choices=["None"] + clip_g_opts,
                    label="CLIP-G (1280d) Adapter",
                    value="None"
                )

            # Adapter Controls
            with gr.Group():
                gr.Markdown("### Adapter Controls")
                strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
                noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
                gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
                use_anchor = gr.Checkbox(label="Use Anchor", value=True)

            # Generation Settings
            with gr.Group():
                gr.Markdown("### Generation Settings")
                with gr.Row():
                    steps = gr.Slider(1, 100, value=25, step=1, label="Steps")
                    cfg_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="CFG Scale")
                scheduler_name = gr.Dropdown(
                    choices=list(SCHEDULERS.keys()),
                    value="DPM++ 2M",
                    label="Scheduler"
                )
                with gr.Row():
                    width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                seed = gr.Number(value=-1, label="Seed (-1 for random)")

            run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output
            with gr.Group():
                gr.Markdown("### Generated Image")
                out_img = gr.Image(label="Result", height=400)

            # Visualizations
            with gr.Group():
                gr.Markdown("### Adapter Visualizations")
                with gr.Row():
                    delta_l = gr.Image(label="Δ CLIP-L", height=200)
                    gate_l = gr.Image(label="Gate CLIP-L", height=200)
                with gr.Row():
                    delta_g = gr.Image(label="Δ CLIP-G", height=200)
                    gate_g = gr.Image(label="Gate CLIP-G", height=200)

            # Stats
            with gr.Group():
                gr.Markdown("### Adapter Statistics")
                stats_l = gr.Textbox(label="CLIP-L Stats", interactive=False)
                stats_g = gr.Textbox(label="CLIP-G Stats", interactive=False)

    # Event handlers
    def process_adapters(adapter_l_val, adapter_g_val):
        # Map the UI's "None" sentinel back to a real None
        adapter_l_processed = None if adapter_l_val == "None" else adapter_l_val
        adapter_g_processed = None if adapter_g_val == "None" else adapter_g_val
        return adapter_l_processed, adapter_g_processed

    def run_inference(*args):
        # args[2] and args[3] are the adapter dropdown selections
        adapter_l_processed, adapter_g_processed = process_adapters(args[2], args[3])
        new_args = list(args)
        new_args[2] = adapter_l_processed
        new_args[3] = adapter_g_processed
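        # Positional order must match the inputs list wired to run_btn.click below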
        return infer(*new_args)

    run_btn.click(
        fn=run_inference,
        inputs=[
            prompt, negative_prompt, adapter_l, adapter_g,
            strength, noise, gate_prob, use_anchor,
            steps, cfg_scale, scheduler_name, width, height, seed
        ],
        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
    )

if __name__ == "__main__":
    demo.launch()
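# Usage note (assumption: this file is saved as app.py, the Space entry point):
#   python app.py  ->  serves the demo locally; on Hugging Face Spaces the
#   @spaces.GPU decorator requests GPU time for each infer() call.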