"""Gradio demo: steer SDXL text conditioning with two T5→CLIP shunt adapters.

One adapter targets the CLIP-L (768-d) stream and one the CLIP-G (1280-d)
stream. Each is called with the flan-t5-base encoder output and the matching
CLIP sequence of the prompt. Its outputs (delta, gate, anchor) are blended
into that sequence before the two streams are concatenated and passed to the
SDXL pipeline. Heatmaps of the applied deltas/gates and summary statistics
are returned alongside the image.
"""

import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure
from safetensors.torch import load_file, safe_open
from transformers import T5EncoderModel, T5Tokenizer

from configs import T5_SHUNT_REPOS
from two_stream_shunt_adapter import TwoStreamShuntAdapter

# ─── Device & Model Setup ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
).to(device)

# Width of the CLIP-L part of the 2048-d SDXL text sequence. The rest
# (1280) is CLIP-G. This file concatenates the streams in L‖G order.
CLIP_L_DIM = 768

# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]

# ─── Loader ───────────────────────────────────────────────────
# Loaded adapters keyed by (repo, filename). Each repo in this file is
# always paired with the same config, so the config is not part of the key.
_ADAPTER_CACHE = {}


def load_adapter(repo, filename, config):
    """Return the TwoStreamShuntAdapter stored in ``repo/filename``, in eval mode on ``device``.

    The weights file is fetched with ``hf_hub_download`` and loaded into a
    freshly built ``TwoStreamShuntAdapter(config)``. The resulting module is
    memoised, so repeated UI clicks reuse it instead of rebuilding it and
    re-copying its weights to the device.

    Args:
        repo: Hugging Face Hub repo id holding the adapter weights.
        filename: ``.safetensors`` file inside ``repo`` (a dropdown value).
        config: Constructor argument for ``TwoStreamShuntAdapter``.

    Returns:
        The loaded adapter module.

    Raises:
        gr.Error: If no adapter file was selected.
    """
    if not filename:
        raise gr.Error(f"No adapter selected for {repo}.")

    key = (repo, filename)
    model = _ADAPTER_CACHE.get(key)
    if model is None:
        path = hf_hub_download(repo_id=repo, filename=filename)
        model = TwoStreamShuntAdapter(config).eval()
        # Tensors are read onto the CPU first and the module is moved
        # afterwards. This keeps the original "fallback-safe for ZeroGPU"
        # loading order.
        with safe_open(path, framework="pt", device="cpu") as f:
            state = {name: f.get_tensor(name) for name in f.keys()}
        model.load_state_dict(state)
        model.to(device)
        _ADAPTER_CACHE[key] = model
    return model


# ─── Visualization ────────────────────────────────────────────
def plot_heat(mat, title):
    """Render ``mat`` as a heatmap and return it as an (H, W, 4) uint8 RGBA array.

    A NumPy image is returned because that is a type ``gr.Image`` outputs can
    display; a raw PNG byte buffer is not.

    Args:
        mat: Array-like of numbers. 0-D/1-D input is shown as a single row.
            Input with more than 2 dims is flattened to
            ``(-1, mat.shape[-1])`` so ``imshow`` does not treat it as RGB.
        title: Axis title.

    Returns:
        The rendered figure as a uint8 RGBA image array.
    """
    data = np.asarray(mat, dtype=np.float32)
    if data.ndim < 2:
        data = np.atleast_2d(data)
    elif data.ndim > 2:
        data = data.reshape(-1, data.shape[-1])

    # Centre the diverging "bwr" colormap on zero so white means "no change"
    # and red/blue reflect the sign. Fall back to ±1 for all-zero/empty input.
    finite = np.abs(data[np.isfinite(data)])
    vmax = float(finite.max()) if finite.size else 0.0
    if vmax == 0.0:
        vmax = 1.0

    # A standalone Figure is not registered with pyplot. It therefore never
    # needs plt.close() and does not share global pyplot state between
    # concurrent requests.
    fig = Figure(figsize=(6, 3), dpi=100)
    ax = fig.subplots()
    im = ax.imshow(
        data, aspect="auto", cmap="bwr", origin="upper", vmin=-vmax, vmax=vmax
    )
    ax.set_title(title)
    fig.colorbar(im, ax=ax)

    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    rgba = plt.imread(buf, format="png")
    if rgba.dtype != np.uint8:
        # PNGs are decoded as floats in [0, 1].
        rgba = (rgba * 255).round().astype(np.uint8)
    return rgba


# ─── Inference ────────────────────────────────────────────────
def _apply_adapter(adapter, t5_seq, clip_in, strength, noise, gate_prob, use_anchor):
    """Run one shunt adapter and blend its outputs into ``clip_in``.

    Returns:
        ``(clip_mod, delta_final, gate_scaled, g_pred, tau)``.
    """
    anchor, delta, _log_sigma, _attn1, _attn2, tau, g_pred, gate = adapter(
        t5_seq, clip_in
    )
    gate_scaled = gate * gate_prob
    delta_final = delta * strength * gate_scaled
    clip_mod = clip_in + delta_final
    if use_anchor:
        # Interpolate toward the adapter's anchor, weighted by the scaled gate.
        clip_mod = clip_mod * (1 - gate_scaled) + anchor * gate_scaled
    if noise > 0:
        clip_mod = clip_mod + torch.randn_like(clip_mod) * noise
    return clip_mod, delta_final, gate_scaled, g_pred, tau


def _heat(tensor, title):
    """Plot a (squeezed) tensor with ``plot_heat``."""
    return plot_heat(tensor.squeeze().float().cpu().numpy(), title)


@torch.no_grad()
def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
    """Generate an SDXL image whose text conditioning is modified by both adapters.

    Args:
        prompt: Text prompt, encoded by flan-t5-base and by SDXL's CLIP encoders.
        adapter_l_file: Weights filename in ``repo_l`` for the CLIP-L adapter.
        adapter_g_file: Weights filename in ``repo_g`` for the CLIP-G adapter.
        strength: Multiplier on each adapter's delta.
        noise: Std of Gaussian noise added to each modified stream (0 = off).
        gate_prob: Multiplier applied to each adapter's gate.
        use_anchor: If true, also blend each stream toward the adapter's anchor.

    Returns:
        Tuple of the image, four heatmap arrays (Δ/gate for L and G), and two
        stats strings.

    Raises:
        gr.Error: If an adapter dropdown is empty.
    """
    t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state

    adapter_l = load_adapter(repo_l, adapter_l_file, config_l)
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g)

    # SDXL's own text encoders provide two things. First, the prompt's CLIP
    # sequences, which the adapters refine (previously random noise was used
    # here). Second, the pooled embeddings, which the SDXL pipeline requires
    # whenever prompt_embeds are passed in.
    clip_seq, _, pooled, neg_pooled = pipe.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )
    # The adapters run in fp32, so cast the (possibly fp16) CLIP streams.
    clip_l_in = clip_seq[..., :CLIP_L_DIM].float()
    clip_g_in = clip_seq[..., CLIP_L_DIM:].float()

    clip_l_mod, delta_l_final, gate_l_scaled, g_pred_l, tau_l = _apply_adapter(
        adapter_l, t5_seq, clip_l_in, strength, noise, gate_prob, use_anchor
    )
    clip_g_mod, delta_g_final, gate_g_scaled, g_pred_g, tau_g = _apply_adapter(
        adapter_g, t5_seq, clip_g_in, strength, noise, gate_prob, use_anchor
    )

    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
    neg_embeds = torch.zeros_like(prompt_embeds)

    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=pooled.to(dtype),
        negative_pooled_prompt_embeds=neg_pooled.to(dtype),
        num_inference_steps=20,
        guidance_scale=5.0,
    ).images[0]

    return (
        image,
        _heat(delta_l_final, "Δ CLIP-L"),
        _heat(gate_l_scaled, "Gate CLIP-L"),
        _heat(delta_g_final, "Δ CLIP-G"),
        _heat(gate_g_scaled, "Gate CLIP-G"),
        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}",
    )


# ─── Gradio App ───────────────────────────────────────────────
with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
    gr.Markdown("# 🧠 Dual Shunt Adapter • SDXL Inference")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
            # Preselect the first adapter so "Run" works without touching the dropdowns.
            adapter_l = gr.Dropdown(
                choices=clip_l_opts,
                value=clip_l_opts[0] if clip_l_opts else None,
                label="CLIP-L (768d) Adapter",
            )
            adapter_g = gr.Dropdown(
                choices=clip_g_opts,
                value=clip_g_opts[0] if clip_g_opts else None,
                label="CLIP-G (1280d) Adapter",
            )
            strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
            noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
            gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
            use_anchor = gr.Checkbox(label="Use Anchor", value=True)
            run_btn = gr.Button("Run")

        with gr.Column():
            out_img = gr.Image(label="Generated Image")
            delta_l = gr.Image(label="Δ CLIP-L")
            gate_l = gr.Image(label="Gate CLIP-L")
            delta_g = gr.Image(label="Δ CLIP-G")
            gate_g = gr.Image(label="Gate CLIP-G")
            stats_l = gr.Textbox(label="CLIP-L Stats")
            stats_g = gr.Textbox(label="CLIP-G Stats")

    run_btn.click(
        fn=infer,
        inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g],
    )

if __name__ == "__main__":
    demo.launch()