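# NOTE: the next line is residue from the hosting site's file viewer, kept for
# provenance only; it is not part of the program.
# AbstractPhil's picture | yes | 7229198 | raw | history blame | 6.47 kB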
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS
# ─── Device & Model Setup ─────────────────────────────────────
# Run on the GPU when torch can see one; otherwise everything stays on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Half precision is selected only together with CUDA; the CPU path uses float32.
dtype = torch.float32 if not torch.cuda.is_available() else torch.float16

# T5 tokenizer and encoder; the encoder is moved to `device` and put in eval mode.
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base")
t5_mod = t5_mod.to(device).eval()

# SDXL base pipeline; the "fp16" weight variant is requested only when running
# in half precision.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant=None if dtype != torch.float16 else "fp16",
)
pipe = pipe.to(device)
# ─── Adapter Configs ──────────────────────────────────────────
# Pull the per-encoder settings out of the shared registry. Each pair is read
# for "clip_l" first and "clip_g" second.
# Selectable adapter entries (used as the dropdown choices in the UI).
clip_l_opts, clip_g_opts = (
    T5_SHUNT_REPOS[encoder]["shunts_available"]["shunt_list"]
    for encoder in ("clip_l", "clip_g")
)
# Repositories the adapter files are downloaded from.
repo_l, repo_g = (T5_SHUNT_REPOS[encoder]["repo"] for encoder in ("clip_l", "clip_g"))
# Configs handed to the adapter constructor.
config_l, config_g = (T5_SHUNT_REPOS[encoder]["config"] for encoder in ("clip_l", "clip_g"))
# ─── Loader ───────────────────────────────────────────────────
from safetensors.torch import safe_open
def load_adapter(repo, filename, config):
    """Download an adapter checkpoint and return it as a ready-to-use module.

    Args:
        repo: Hub repository id passed to ``hf_hub_download``.
        filename: Name of the safetensors file inside ``repo``.
        config: Configuration handed to ``TwoStreamShuntAdapter``.

    Returns:
        The ``TwoStreamShuntAdapter`` in eval mode, with the checkpoint
        weights loaded and moved to the module-level ``device``.
    """
    checkpoint_path = hf_hub_download(repo_id=repo, filename=filename)
    adapter = TwoStreamShuntAdapter(config).eval()
    # Tensors are read onto the CPU first (the original author marks this as
    # the fallback-safe path for ZeroGPU); the move to `device` happens last.
    with safe_open(checkpoint_path, framework="pt", device="cpu") as handle:
        state = {name: handle.get_tensor(name) for name in handle.keys()}
    adapter.load_state_dict(state)
    adapter.to(device)
    return adapter
# ─── Visualization ────────────────────────────────────────────
def plot_heat(mat, title):
    """Render ``mat`` as a heat map and return the PNG bytes in a buffer.

    Args:
        mat: Array-like of values to display. Inputs with fewer than two
            dimensions are promoted to a single-row matrix, because
            ``imshow`` only accepts 2-D (or RGB/RGBA) image data.
        title: Text placed above the axes.

    Returns:
        An ``io.BytesIO`` rewound to offset 0 that holds the PNG image.
    """
    import io

    mat = np.atleast_2d(mat)
    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    try:
        im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
        buf = io.BytesIO()
        # Save through the figure object so the output never depends on which
        # figure pyplot currently considers "active".
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call would leak one figure.
        plt.close(fig)
    buf.seek(0)
    return buf
# ─── Inference ────────────────────────────────────────────────
def _apply_shunt(adapter, t5_seq, clip_in, strength, noise, gate_prob, use_anchor):
    """Run one shunt adapter and blend its output into ``clip_in``.

    Returns ``(clip_mod, delta_final, gate_scaled, g_pred, tau)``.
    """
    # The adapter returns eight values; the log-sigma and the two attention
    # outputs are not used by this demo.
    anchor, delta, _log_sigma, _attn_1, _attn_2, tau, g_pred, gate = adapter(t5_seq, clip_in)
    gate_scaled = gate * gate_prob
    delta_final = delta * strength * gate_scaled
    clip_mod = clip_in + delta_final
    if use_anchor:
        # Interpolate towards the adapter's anchor, weighted by the scaled gate.
        clip_mod = clip_mod * (1 - gate_scaled) + anchor * gate_scaled
    if noise > 0:
        clip_mod = clip_mod + torch.randn_like(clip_mod) * noise
    return clip_mod, delta_final, gate_scaled, g_pred, tau


def _heat_image(tensor, title):
    """Plot ``tensor`` with ``plot_heat`` and decode the PNG for ``gr.Image``."""
    png_buffer = plot_heat(tensor.squeeze().cpu().numpy(), title)
    # gr.Image outputs take an ndarray, a PIL image or a file path -- not a
    # BytesIO -- so the PNG is decoded back into an array here.
    pixels = plt.imread(png_buffer, format="png")
    if np.issubdtype(pixels.dtype, np.floating):
        # matplotlib returns PNG data as floats in [0, 1]; convert to 8-bit.
        pixels = np.round(pixels * 255).astype(np.uint8)
    return pixels


@torch.no_grad()
def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
    """Generate an SDXL image from T5-shunted embeddings plus diagnostics.

    Args:
        prompt: Text encoded with the T5 encoder (and used for the pooled
            SDXL embedding).
        adapter_l_file: Checkpoint filename for the CLIP-L adapter.
        adapter_g_file: Checkpoint filename for the CLIP-G adapter.
        strength: Multiplier applied to each adapter delta.
        noise: Scale of Gaussian noise added to the modified embeddings;
            nothing is added when it is not positive.
        gate_prob: Multiplier applied to each adapter gate.
        use_anchor: Whether to blend the result towards the adapter anchor.

    Returns:
        A 7-tuple: the generated image, four heat-map images (delta and gate
        for CLIP-L, then delta and gate for CLIP-G) and two stat strings.

    Raises:
        gr.Error: If either adapter has not been selected.
    """
    # The dropdowns have no default, so an untouched one arrives as None and
    # would otherwise fail deep inside the hub download.
    if not adapter_l_file or not adapter_g_file:
        raise gr.Error("Select both a CLIP-L and a CLIP-G adapter before running.")

    t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state

    adapter_l = load_adapter(repo_l, adapter_l_file, config_l)
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g)

    # NOTE(review): both CLIP streams start from Gaussian noise rather than
    # real CLIP encoder output -- confirm this is intended.
    batch = t5_seq.shape[0]
    clip_l_in = torch.randn(batch, 77, 768).to(device)
    clip_g_in = torch.randn(batch, 77, 1280).to(device)

    clip_l_mod, delta_l_final, gate_l_scaled, g_pred_l, tau_l = _apply_shunt(
        adapter_l, t5_seq, clip_l_in, strength, noise, gate_prob, use_anchor
    )
    clip_g_mod, delta_g_final, gate_g_scaled, g_pred_g, tau_g = _apply_shunt(
        adapter_g, t5_seq, clip_g_in, strength, noise, gate_prob, use_anchor
    )

    # SDXL conditions on the CLIP-L and CLIP-G sequences concatenated along
    # the feature axis (768 + 1280).
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
    neg_embeds = torch.zeros_like(prompt_embeds)

    # The SDXL pipeline rejects `prompt_embeds` unless pooled embeddings are
    # supplied as well. Take the pooled vector for this prompt from the
    # pipeline's own text encoders and pair it with a zero negative, matching
    # the zeroed negative sequence above.
    pooled_embeds = pipe.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=False,
    )[2].to(dtype)
    neg_pooled_embeds = torch.zeros_like(pooled_embeds)

    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=pooled_embeds,
        negative_pooled_prompt_embeds=neg_pooled_embeds,
        num_inference_steps=20,
        guidance_scale=5.0,
    ).images[0]

    return (
        image,
        _heat_image(delta_l_final, "Δ CLIP-L"),
        _heat_image(gate_l_scaled, "Gate CLIP-L"),
        _heat_image(delta_g_final, "Δ CLIP-G"),
        _heat_image(gate_g_scaled, "Gate CLIP-G"),
        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}",
    )
# ─── Gradio App ───────────────────────────────────────────────
# Two-column layout: controls on the left, generated image and adapter
# diagnostics on the right. User-facing text uses the intended Unicode
# characters instead of their mis-decoded byte sequences.
with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
    gr.Markdown("# 🧠 Dual Shunt Adapter ‒ SDXL Inference")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
            adapter_l = gr.Dropdown(choices=clip_l_opts, label="CLIP-L (768d) Adapter")
            adapter_g = gr.Dropdown(choices=clip_g_opts, label="CLIP-G (1280d) Adapter")
            strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
            noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
            gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
            use_anchor = gr.Checkbox(label="Use Anchor", value=True)
            run_btn = gr.Button("Run")
        with gr.Column():
            out_img = gr.Image(label="Generated Image")
            delta_l = gr.Image(label="Δ CLIP-L")
            gate_l = gr.Image(label="Gate CLIP-L")
            delta_g = gr.Image(label="Δ CLIP-G")
            gate_g = gr.Image(label="Gate CLIP-G")
            stats_l = gr.Textbox(label="CLIP-L Stats")
            stats_g = gr.Textbox(label="CLIP-G Stats")
    # The input and output lists follow the parameter order and the return
    # tuple order of `infer`.
    run_btn.click(
        fn=infer,
        inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g],
    )

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()