# Spaces: Running on Zero
import torch | |
import gradio as gr | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from transformers import T5Tokenizer, T5EncoderModel | |
from diffusers import DiffusionPipeline | |
from safetensors.torch import load_file | |
from huggingface_hub import hf_hub_download | |
from two_stream_shunt_adapter import TwoStreamShuntAdapter | |
from configs import T5_SHUNT_REPOS | |
# ─── Device & Model Setup ─────────────────────────────────────
# Use CUDA with half precision when a GPU is visible; otherwise fall back to
# CPU with full precision.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
dtype = torch.float16 if _cuda_ok else torch.float32

# FLAN-T5 base tokenizer and encoder; the encoder lives on `device` in eval mode.
_T5_CHECKPOINT = "google/flan-t5-base"
t5_tok = T5Tokenizer.from_pretrained(_T5_CHECKPOINT)
t5_mod = T5EncoderModel.from_pretrained(_T5_CHECKPOINT)
t5_mod = t5_mod.to(device).eval()

# SDXL base pipeline; the "fp16" weight variant is requested only when the
# pipeline dtype is float16.
_sdxl_variant = "fp16" if dtype == torch.float16 else None
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant=_sdxl_variant,
)
pipe = pipe.to(device)
# ─── Adapter Configs ──────────────────────────────────────────
# Look up each encoder's registry entry once, then expose its selectable
# checkpoint files, its Hub repo id and the config used to build the adapter.
_clip_l_entry = T5_SHUNT_REPOS["clip_l"]
_clip_g_entry = T5_SHUNT_REPOS["clip_g"]

clip_l_opts = _clip_l_entry["shunts_available"]["shunt_list"]
clip_g_opts = _clip_g_entry["shunts_available"]["shunt_list"]
repo_l, repo_g = _clip_l_entry["repo"], _clip_g_entry["repo"]
config_l, config_g = _clip_l_entry["config"], _clip_g_entry["config"]
# ─── Loader ───────────────────────────────────────────────────
from safetensors.torch import safe_open | |
def load_adapter(repo, filename, config):
    """Download an adapter checkpoint and return it as a ready-to-run module.

    Args:
        repo: Repository id passed to ``hf_hub_download``.
        filename: Name of the safetensors file inside ``repo``.
        config: Configuration handed to ``TwoStreamShuntAdapter``.

    Returns:
        A ``TwoStreamShuntAdapter`` in eval mode with the checkpoint's tensors
        loaded, moved to the module-level ``device``.
    """
    checkpoint_path = hf_hub_download(repo_id=repo, filename=filename)

    adapter = TwoStreamShuntAdapter(config).eval()

    # Tensors are read onto the CPU first (the original notes this as the
    # fallback-safe path for ZeroGPU); the module is moved to the target
    # device only after its weights are in place.
    with safe_open(checkpoint_path, framework="pt", device="cpu") as handle:
        state = {name: handle.get_tensor(name) for name in handle.keys()}

    adapter.load_state_dict(state)
    adapter.to(device)
    return adapter
# ─── Visualization ────────────────────────────────────────────
def plot_heat(mat, title):
    """Render an array as a blue-white-red heat map and return it as PNG bytes.

    Args:
        mat: Array-like image data accepted by ``Axes.imshow``.
        title: Text placed above the plot.

    Returns:
        An ``io.BytesIO`` rewound to offset 0 that holds the PNG image.

    NOTE(review): callers route this buffer into ``gr.Image`` outputs; confirm
    the installed Gradio version accepts a ``BytesIO`` there.
    """
    import io

    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    try:
        im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
        buf = io.BytesIO()
        # Save via the figure object so the output never depends on which
        # figure pyplot currently considers "active".
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, each call leaks one figure for the process lifetime.
        plt.close(fig)
    buf.seek(0)
    return buf
# ─── Inference ────────────────────────────────────────────────
def _apply_shunt(adapter, t5_seq, clip_in, strength, noise, gate_prob, use_anchor):
    """Run one shunt adapter and fold its output into ``clip_in``.

    Returns:
        ``(clip_mod, delta_final, gate_scaled, tau, g_pred)`` where
        ``gate_scaled = gate * gate_prob``,
        ``delta_final = delta * strength * gate_scaled`` and ``clip_mod`` is
        ``clip_in + delta_final``, optionally blended toward the adapter's
        anchor and perturbed with Gaussian noise.
    """
    # The adapter returns eight values; log-sigma and the two attention maps
    # are not used by this demo.
    anchor, delta, _log_sigma, _attn_a, _attn_b, tau, g_pred, gate = adapter(
        t5_seq, clip_in
    )
    gate_scaled = gate * gate_prob
    delta_final = delta * strength * gate_scaled
    clip_mod = clip_in + delta_final
    if use_anchor:
        clip_mod = clip_mod * (1 - gate_scaled) + anchor * gate_scaled
    if noise > 0:
        clip_mod += torch.randn_like(clip_mod) * noise
    return clip_mod, delta_final, gate_scaled, tau, g_pred


def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
    """Generate an SDXL image from T5-conditioned, adapter-modified embeddings.

    Args:
        prompt: Text encoded with the FLAN-T5 encoder.
        adapter_l_file: Checkpoint filename in ``repo_l`` for the 768-wide adapter.
        adapter_g_file: Checkpoint filename in ``repo_g`` for the 1280-wide adapter.
        strength: Multiplier applied to each adapter's delta.
        noise: Scale of Gaussian noise added to the modified embeddings
            (skipped when not greater than 0).
        gate_prob: Multiplier applied to each adapter's gate.
        use_anchor: When true, blend the modified embeddings toward the
            adapter's anchor, weighted by the scaled gate.

    Returns:
        A 7-tuple: the generated image, four heat-map PNG buffers (CLIP-L
        delta, CLIP-L gate, CLIP-G delta, CLIP-G gate) and two stats strings.

    Raises:
        gr.Error: If either adapter file has not been selected.
    """
    if not adapter_l_file or not adapter_g_file:
        # Without this, a missing dropdown selection surfaces as an opaque
        # failure from the checkpoint download.
        raise gr.Error("Select both a CLIP-L and a CLIP-G adapter before running.")

    # Everything here is pure inference. Disabling autograd avoids building a
    # graph and lets the results be converted to NumPy for plotting (tensors
    # that require grad cannot be passed to .numpy()).
    with torch.no_grad():
        t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
        t5_seq = t5_mod(t5_ids).last_hidden_state

        adapter_l = load_adapter(repo_l, adapter_l_file, config_l)
        adapter_g = load_adapter(repo_g, adapter_g_file, config_g)

        # Random stand-ins for the CLIP token embeddings: 77 tokens each,
        # 768-wide for CLIP-L and 1280-wide for CLIP-G. Both are drawn before
        # either adapter runs, so the RNG call order is unchanged.
        batch = t5_seq.shape[0]
        clip_l_in = torch.randn(batch, 77, 768).to(device)
        clip_g_in = torch.randn(batch, 77, 1280).to(device)

        clip_l_mod, delta_l_final, gate_l_scaled, tau_l, g_pred_l = _apply_shunt(
            adapter_l, t5_seq, clip_l_in, strength, noise, gate_prob, use_anchor
        )
        clip_g_mod, delta_g_final, gate_g_scaled, tau_g, g_pred_g = _apply_shunt(
            adapter_g, t5_seq, clip_g_in, strength, noise, gate_prob, use_anchor
        )

        prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
        neg_embeds = torch.zeros_like(prompt_embeds)

        # The SDXL pipeline rejects prompt_embeds / negative_prompt_embeds
        # unless the matching pooled embeddings are passed too. The adapters
        # produce no pooled vector, so neutral zero vectors are supplied.
        # Assumes the pooled width is 1280 for SDXL base — TODO confirm
        # against pipe.text_encoder_2.config.projection_dim.
        pooled_embeds = torch.zeros(
            batch, 1280, device=prompt_embeds.device, dtype=prompt_embeds.dtype
        )
        neg_pooled_embeds = torch.zeros_like(pooled_embeds)

    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=pooled_embeds,
        negative_pooled_prompt_embeds=neg_pooled_embeds,
        num_inference_steps=20,
        guidance_scale=5.0,
    ).images[0]

    def _as_array(tensor):
        # Drop singleton dims and move to host memory for matplotlib.
        return tensor.detach().squeeze().cpu().numpy()

    return (
        image,
        plot_heat(_as_array(delta_l_final), "Δ CLIP-L"),
        plot_heat(_as_array(gate_l_scaled), "Gate CLIP-L"),
        plot_heat(_as_array(delta_g_final), "Δ CLIP-G"),
        plot_heat(_as_array(gate_g_scaled), "Gate CLIP-G"),
        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}",
    )
# ─── Gradio App ───────────────────────────────────────────────
# Two-column layout: controls on the left, generated image plus diagnostic
# heat maps and stats on the right. User-facing strings below had their
# non-ASCII characters mis-encoded; they are restored here.
with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
    gr.Markdown("# 🧠 Dual Shunt Adapter • SDXL Inference")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
            adapter_l = gr.Dropdown(choices=clip_l_opts, label="CLIP-L (768d) Adapter")
            adapter_g = gr.Dropdown(choices=clip_g_opts, label="CLIP-G (1280d) Adapter")
            strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
            noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
            gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
            use_anchor = gr.Checkbox(label="Use Anchor", value=True)
            run_btn = gr.Button("Run")
        with gr.Column():
            out_img = gr.Image(label="Generated Image")
            delta_l = gr.Image(label="Δ CLIP-L")
            gate_l = gr.Image(label="Gate CLIP-L")
            delta_g = gr.Image(label="Δ CLIP-G")
            gate_g = gr.Image(label="Gate CLIP-G")
            stats_l = gr.Textbox(label="CLIP-L Stats")
            stats_g = gr.Textbox(label="CLIP-G Stats")
    # Input and output order must match infer()'s parameter list and the
    # 7-tuple it returns.
    run_btn.click(
        fn=infer,
        inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g],
    )

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()