import spaces
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS

# ─── Device & Model Setup ────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# T5 model for semantic understanding
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

# SDXL pipeline with its two CLIP text encoders
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
    use_safetensors=True
).to(device)

# Available schedulers
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}

# ─── Adapter Configs ─────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]

# ─── Loader ──────────────────────────────────────────────────
from safetensors.torch import safe_open

def load_adapter(repo, filename, config):
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    model.to(device)
    return model

# ─── Visualization ───────────────────────────────────────────
def plot_heat(mat, title):
    import io
    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
    ax.set_title(title)
    plt.colorbar(im, ax=ax)
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    plt.close(fig)
    # Return a PIL image; gr.Image accepts PIL/NumPy/filepath, not a raw BytesIO buffer
    return Image.open(buf)

# ─── SDXL Text Encoding ──────────────────────────────────────
def encode_sdxl_prompt(prompt, negative_prompt=""):
    """Generate proper CLIP-L and CLIP-G embeddings using SDXL's text encoders."""
    # Tokenize for both encoders
    tokens_l = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    tokens_g = pipe.tokenizer_2(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    # Negative prompts
    neg_tokens_l = pipe.tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    neg_tokens_g = pipe.tokenizer_2(
        negative_prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)
    with torch.no_grad():
        # CLIP-L (768-d): output[0] is the token-level hidden-state sequence
        clip_l_embeds = pipe.text_encoder(tokens_l)[0]
        neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]

        # CLIP-G (1280-d): output[0] is the pooled embedding and output[1] the
        # token sequence (the opposite ordering of CLIP-L)
        clip_g_output = pipe.text_encoder_2(tokens_g)
        clip_g_embeds = clip_g_output[1]          # sequence embeddings
        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
        neg_clip_g_embeds = neg_clip_g_output[1]  # sequence embeddings

        # Pooled embeddings required by SDXL's added conditioning
        pooled_embeds = clip_g_output[0]
        neg_pooled_embeds = neg_clip_g_output[0]

    return {
        "clip_l": clip_l_embeds,
        "clip_g": clip_g_embeds,
        "neg_clip_l": neg_clip_l_embeds,
        "neg_clip_g": neg_clip_g_embeds,
        "pooled": pooled_embeds,
        "neg_pooled": neg_pooled_embeds
    }

# ─── Inference ───────────────────────────────────────────────
@spaces.GPU  # ZeroGPU: allocate a GPU for this call (the otherwise-unused `spaces` import indicates a ZeroGPU Space)
def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
          use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
    # Set seed for reproducibility (cast: Gradio Number inputs may arrive as floats)
    seed = int(seed)
    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Set scheduler
    if scheduler_name in SCHEDULERS:
        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)

    # Get T5 embeddings for semantic understanding, padded to 77 tokens like CLIP
    t5_ids = t5_tok(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        truncation=True
    ).input_ids.to(device)
    with torch.no_grad():
        t5_seq = t5_mod(t5_ids).last_hidden_state

    # Get proper SDXL CLIP embeddings
    clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)

    # Debug shapes
    print(f"T5 seq shape: {t5_seq.shape}")
    print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
    print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")

    # Load adapters
    adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None

    # Apply CLIP-L adapter
    if adapter_l is not None:
        with torch.no_grad():  # adapter outputs are only read, never backpropagated
            anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
        gate_l_scaled = gate_l * gate_prob
        delta_l_final = delta_l * strength * gate_l_scaled
        clip_l_mod = clip_embeds["clip_l"] + delta_l_final
        if use_anchor:
            clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
        if noise > 0:
            clip_l_mod += torch.randn_like(clip_l_mod) * noise
    else:
        clip_l_mod = clip_embeds["clip_l"]
        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
        g_pred_l = torch.tensor(0.0)
        tau_l = torch.tensor(0.0)

    # Apply CLIP-G adapter
    if adapter_g is not None:
        with torch.no_grad():
            anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
        gate_g_scaled = gate_g * gate_prob
        delta_g_final = delta_g * strength * gate_g_scaled
        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
        if use_anchor:
            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
        if noise > 0:
            clip_g_mod += torch.randn_like(clip_g_mod) * noise
    else:
        clip_g_mod = clip_embeds["clip_g"]
        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
        g_pred_g = torch.tensor(0.0)
        tau_g = torch.tensor(0.0)

    # Combine embeddings in SDXL format: [CLIP-L (768) | CLIP-G (1280)] -> 2048
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
    neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)

    # Generate image with proper SDXL parameters
    image = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=clip_embeds["pooled"],
        negative_prompt_embeds=neg_embeds,
        negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        num_images_per_prompt=1,  # explicitly set
        generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
    ).images[0]

    return (
        image,
        plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
        plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
        plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
        plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
    )

# ─── Gradio Interface ────────────────────────────────────────
with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement")
    gr.Markdown("Enhance SDXL generation by using T5 semantic understanding to modify the CLIP embeddings.")

    with gr.Row():
        with gr.Column(scale=1):
            # Prompts
            with gr.Group():
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(
                    label="Prompt",
                    value="a futuristic control station with holographic displays",
                    lines=3
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, low quality, distorted",
                    lines=2
                )

            # Adapters
            with gr.Group():
                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(
                    choices=["None"] + clip_l_opts,
                    label="CLIP-L (768d) Adapter",
                    value="None"
                )
                adapter_g = gr.Dropdown(
                    choices=["None"] + clip_g_opts,
                    label="CLIP-G (1280d) Adapter",
                    value="None"
                )

            # Adapter Controls
            with gr.Group():
                gr.Markdown("### Adapter Controls")
                strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
                noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
                gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
                use_anchor = gr.Checkbox(label="Use Anchor", value=True)

            # Generation Settings
            with gr.Group():
                gr.Markdown("### Generation Settings")
                with gr.Row():
                    steps = gr.Slider(1, 100, value=25, step=1, label="Steps")
                    cfg_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="CFG Scale")
                scheduler_name = gr.Dropdown(
                    choices=list(SCHEDULERS.keys()),
                    value="DPM++ 2M",
                    label="Scheduler"
                )
                with gr.Row():
                    width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                seed = gr.Number(value=-1, label="Seed (-1 for random)")

            run_btn = gr.Button("Generate", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output
            with gr.Group():
                gr.Markdown("### Generated Image")
                out_img = gr.Image(label="Result", height=400)

            # Visualizations
            with gr.Group():
                gr.Markdown("### Adapter Visualizations")
                with gr.Row():
                    delta_l = gr.Image(label="Δ CLIP-L", height=200)
                    gate_l = gr.Image(label="Gate CLIP-L", height=200)
                with gr.Row():
                    delta_g = gr.Image(label="Δ CLIP-G", height=200)
                    gate_g = gr.Image(label="Gate CLIP-G", height=200)

            # Stats
            with gr.Group():
                gr.Markdown("### Adapter Statistics")
                stats_l = gr.Textbox(label="CLIP-L Stats", interactive=False)
                stats_g = gr.Textbox(label="CLIP-G Stats", interactive=False)

    # Event handlers
    def process_adapters(adapter_l_val, adapter_g_val):
        # Convert the "None" dropdown choice back to a real None for processing
        adapter_l_processed = None if adapter_l_val == "None" else adapter_l_val
        adapter_g_processed = None if adapter_g_val == "None" else adapter_g_val
        return adapter_l_processed, adapter_g_processed

    def run_inference(*args):
        # Process adapter selections, then call inference with the processed values
        adapter_l_processed, adapter_g_processed = process_adapters(args[2], args[3])
        new_args = list(args)
        new_args[2] = adapter_l_processed
        new_args[3] = adapter_g_processed
        return infer(*new_args)

    run_btn.click(
        fn=run_inference,
        inputs=[
            prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob,
            use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
        ],
        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
    )


if __name__ == "__main__":
    demo.launch()