Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,467 Bytes
ca066a9 403ae01 ca066a9 403ae01 ca066a9 620a643 ca066a9 7229198 ca066a9 7229198 ca066a9 7229198 ca066a9 403ae01 ca066a9 403ae01 ca066a9 403ae01 ca066a9 403ae01 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS
# ─── Device & Model Setup ─────────────────────────────────────
# Use the GPU in half precision when CUDA is present; otherwise stay on the
# CPU in full precision.
if torch.cuda.is_available():
    device, dtype = torch.device("cuda"), torch.float16
else:
    device, dtype = torch.device("cpu"), torch.float32

# FLAN-T5 tokenizer and encoder; the encoder is moved to the device and put
# in eval mode.
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base")
t5_mod = t5_mod.to(device).eval()

# SDXL base pipeline; the fp16 weight variant is only requested when the
# pipeline actually runs in half precision.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant=None if dtype != torch.float16 else "fp16",
)
pipe = pipe.to(device)

# ─── Adapter Configs ──────────────────────────────────────────
# For each CLIP stream: the selectable checkpoint list, the repository that
# hosts the checkpoints, and the adapter configuration.
clip_l_opts, clip_g_opts = (
    T5_SHUNT_REPOS[stream]["shunts_available"]["shunt_list"]
    for stream in ("clip_l", "clip_g")
)
repo_l, repo_g = (T5_SHUNT_REPOS[stream]["repo"] for stream in ("clip_l", "clip_g"))
config_l, config_g = (T5_SHUNT_REPOS[stream]["config"] for stream in ("clip_l", "clip_g"))
# ─── Loader ───────────────────────────────────────────────────
from safetensors.torch import safe_open
def load_adapter(repo, filename, config):
    """Download a shunt-adapter checkpoint and return the adapter ready for inference.

    Args:
        repo: Hub repository id that hosts the checkpoint.
        filename: Name of the safetensors file inside ``repo``.
        config: Configuration passed to ``TwoStreamShuntAdapter``.

    Returns:
        The adapter in eval mode, with the checkpoint weights loaded, moved
        to the module-level ``device``.
    """
    checkpoint_path = hf_hub_download(repo_id=repo, filename=filename)

    # Fallback-safe loading for ZeroGPU: every tensor is read onto the CPU
    # first, and the assembled module is moved to the device afterwards.
    adapter = TwoStreamShuntAdapter(config)
    adapter.eval()
    with safe_open(checkpoint_path, framework="pt", device="cpu") as reader:
        state = {name: reader.get_tensor(name) for name in reader.keys()}
    adapter.load_state_dict(state)
    adapter.to(device)
    return adapter
# ─── Visualization ────────────────────────────────────────────
def plot_heat(mat, title):
    """Render a matrix as a blue-white-red heatmap and return it as PNG bytes.

    Args:
        mat: Array-like of values to draw. Inputs with fewer than two
            dimensions are promoted to 2-D, because ``imshow`` rejects them.
        title: Text shown above the plot.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 that holds the PNG image.
    """
    import io

    # A fully squeezed (1, n, 1) tensor arrives here as a 1-D vector; draw it
    # as a single-row strip instead of letting imshow raise.
    mat = np.atleast_2d(mat)

    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    try:
        im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # pyplot keeps every figure it creates alive until it is closed, so
        # without this each call would leave one more figure in memory.
        plt.close(fig)
    buf.seek(0)
    return buf
# ─── Inference ────────────────────────────────────────────────
def _apply_shunt(adapter, t5_seq, clip_in, strength, noise, gate_prob, use_anchor):
    """Run one shunt adapter and blend its outputs into the CLIP-stream tensor.

    Returns ``(clip_mod, delta_final, gate_scaled, g_pred, tau)``.
    """
    # The adapter returns an 8-tuple; log-sigma and the two attention maps are
    # not used by this demo.
    anchor, delta, _log_sigma, _attn_1, _attn_2, tau, g_pred, gate = adapter(t5_seq, clip_in)

    gate_scaled = gate * gate_prob
    delta_final = delta * strength * gate_scaled
    clip_mod = clip_in + delta_final
    if use_anchor:
        # Cross-fade towards the adapter's anchor, weighted by the scaled gate.
        clip_mod = clip_mod * (1 - gate_scaled) + anchor * gate_scaled
    if noise > 0:
        clip_mod = clip_mod + torch.randn_like(clip_mod) * noise
    return clip_mod, delta_final, gate_scaled, g_pred, tau


def _heatmap_image(tensor, title):
    """Draw ``tensor`` with ``plot_heat`` and decode the PNG into a uint8 array."""
    values = np.atleast_2d(tensor.squeeze().float().cpu().numpy())
    png = plot_heat(values, title)
    # gr.Image outputs take arrays, PIL images or file paths, not in-memory
    # byte buffers, so the PNG is decoded before it is handed to Gradio.
    # imread yields floats in [0, 1] for PNG data.
    pixels = plt.imread(png, format="png")
    return (pixels * 255).round().astype(np.uint8)


@torch.no_grad()
def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
    """Generate an SDXL image from T5-driven shunt adapters, plus diagnostics.

    Args:
        prompt: Text encoded with the FLAN-T5 encoder.
        adapter_l_file: Checkpoint filename for the CLIP-L (768-wide) adapter.
        adapter_g_file: Checkpoint filename for the CLIP-G (1280-wide) adapter.
        strength: Multiplier applied to each adapter's delta.
        noise: Scale of the Gaussian noise added to each stream; 0 disables it.
        gate_prob: Multiplier applied to each adapter's gate.
        use_anchor: Whether to blend each stream towards the adapter's anchor.

    Returns:
        A 7-tuple: the generated image, four heatmap images (delta and gate
        for CLIP-L, then delta and gate for CLIP-G) as uint8 arrays, and one
        statistics string per stream.

    Raises:
        gr.Error: If either adapter has not been selected.
    """
    # An unselected dropdown yields None, which would otherwise surface as an
    # opaque failure inside the checkpoint download.
    if not adapter_l_file or not adapter_g_file:
        raise gr.Error("Select both a CLIP-L and a CLIP-G adapter before running.")

    t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state

    adapter_l = load_adapter(repo_l, adapter_l_file, config_l)
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g)

    # NOTE(review): both CLIP streams start from Gaussian noise rather than
    # from CLIP encodings of the prompt, so the prompt reaches the image only
    # through the adapters - confirm this is intended.
    batch = t5_seq.shape[0]
    clip_l_in = torch.randn(batch, 77, 768, device=device)
    clip_g_in = torch.randn(batch, 77, 1280, device=device)

    clip_l_mod, delta_l_final, gate_l_scaled, g_pred_l, tau_l = _apply_shunt(
        adapter_l, t5_seq, clip_l_in, strength, noise, gate_prob, use_anchor
    )
    clip_g_mod, delta_g_final, gate_g_scaled, g_pred_g, tau_g = _apply_shunt(
        adapter_g, t5_seq, clip_g_in, strength, noise, gate_prob, use_anchor
    )

    # SDXL conditions on the two streams concatenated along the feature axis.
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
    neg_embeds = torch.zeros_like(prompt_embeds)

    # The SDXL pipeline rejects prompt_embeds / negative_prompt_embeds that
    # arrive without their pooled counterparts. This flow has no pooled CLIP
    # output, so zeros are passed as a placeholder.
    # NOTE(review): replace with a real pooled embedding if one is available.
    pooled_dim = pipe.text_encoder_2.config.projection_dim
    pooled_embeds = prompt_embeds.new_zeros(batch, pooled_dim)
    neg_pooled_embeds = torch.zeros_like(pooled_embeds)

    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=pooled_embeds,
        negative_pooled_prompt_embeds=neg_pooled_embeds,
        num_inference_steps=20,
        guidance_scale=5.0,
    ).images[0]

    return (
        image,
        _heatmap_image(delta_l_final, "Δ CLIP-L"),
        _heatmap_image(gate_l_scaled, "Gate CLIP-L"),
        _heatmap_image(delta_g_final, "Δ CLIP-G"),
        _heatmap_image(gate_g_scaled, "Gate CLIP-G"),
        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}",
    )
# ─── Gradio App ───────────────────────────────────────────────
def _default_choice(options):
    """Return the value of the first dropdown choice, or None when there are none."""
    if not options:
        return None
    first = options[0]
    # Gradio choices may be plain values or (label, value) pairs.
    return first[1] if isinstance(first, (tuple, list)) else first


with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
    # NOTE(review): the heading emoji was reconstructed from mis-decoded
    # bytes - confirm it matches the original.
    gr.Markdown("# 🧠 Dual Shunt Adapter • SDXL Inference")

    with gr.Row():
        # Left column: inputs, in the order infer() takes them.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
            # Preselect the first checkpoint so that "Run" never submits an
            # empty adapter selection.
            adapter_l = gr.Dropdown(
                choices=clip_l_opts,
                value=_default_choice(clip_l_opts),
                label="CLIP-L (768d) Adapter",
            )
            adapter_g = gr.Dropdown(
                choices=clip_g_opts,
                value=_default_choice(clip_g_opts),
                label="CLIP-G (1280d) Adapter",
            )
            strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
            noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
            gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
            use_anchor = gr.Checkbox(label="Use Anchor", value=True)
            run_btn = gr.Button("Run")

        # Right column: outputs, in the order infer() returns them.
        with gr.Column():
            out_img = gr.Image(label="Generated Image")
            delta_l = gr.Image(label="Δ CLIP-L")
            gate_l = gr.Image(label="Gate CLIP-L")
            delta_g = gr.Image(label="Δ CLIP-G")
            gate_g = gr.Image(label="Gate CLIP-G")
            stats_l = gr.Textbox(label="CLIP-L Stats")
            stats_g = gr.Textbox(label="CLIP-G Stats")

    run_btn.click(
        fn=infer,
        inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g],
    )

if __name__ == "__main__":
    demo.launch()
|