import io

import spaces
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from safetensors.torch import safe_open
from huggingface_hub import hf_hub_download

from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS
# ─── Global Variables ─────────────────────────────────────────
t5_tok = None
t5_mod = None
pipe = None
# Available schedulers
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}
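
# A scheduler can be swapped at runtime while keeping the pipeline's existing
# scheduler config, e.g. (a sketch; assumes `pipe` is already loaded):
#   pipe.scheduler = SCHEDULERS["DDIM"].from_config(pipe.scheduler.config)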
# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]
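
# Schema expected of each T5_SHUNT_REPOS entry (inferred from the lookups
# above): a Hub repo id under "repo", an adapter config dict under "config",
# and checkpoint filenames under ["shunts_available"]["shunt_list"].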
# ─── Helper Functions ─────────────────────────────────────────
def load_adapter(repo, filename, config, device):
    """Load adapter weights from a safetensors checkpoint on the Hub."""
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    return model.to(device)
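
# Example usage (the filename here is hypothetical):
#   adapter = load_adapter(repo_l, "my_shunt.safetensors", config_l, torch.device("cuda"))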
def plot_heat(mat, title):
    """Create a heatmap visualization with proper shape handling."""
    if isinstance(mat, torch.Tensor):
        mat = mat.detach().cpu().numpy()

    # Ensure a 2D array for visualization
    if len(mat.shape) == 1:
        # 1D array - reshape to a single row
        mat = mat.reshape(1, -1)
    elif len(mat.shape) == 3:
        # 3D array - drop or average over the batch dimension
        if mat.shape[0] == 1:
            mat = mat.squeeze(0)
        else:
            mat = mat.mean(axis=0)
    elif len(mat.shape) > 3:
        # Flatten higher dimensions
        mat = mat.reshape(-1, mat.shape[-1])

    # Create figure with proper DPI
    plt.figure(figsize=(8, 4), dpi=100)
    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper", interpolation="nearest")
    plt.title(title, fontsize=12, fontweight="bold")
    # Rows are token positions, columns are feature dimensions
    plt.xlabel("Feature Dimension")
    plt.ylabel("Token Position")
    plt.colorbar(shrink=0.8)
    plt.tight_layout()

    # Convert the figure to a numpy array for Gradio via a PNG buffer
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
    buf.seek(0)
    pil_image = Image.open(buf)
    arr = np.array(pil_image)
    plt.close()
    return arr
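
# In this app plot_heat typically receives (77, 768) or (77, 1280) delta maps
# and per-token gate columns, so rows map to tokens and columns to features.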
def encode_sdxl_prompt(pipe, prompt, negative_prompt, device):
    """Generate CLIP-L and CLIP-G embeddings using SDXL's two text encoders."""
    # Tokenize for both encoders
    tokens_l = pipe.tokenizer(
        prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    ).input_ids.to(device)
    tokens_g = pipe.tokenizer_2(
        prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    ).input_ids.to(device)
    neg_tokens_l = pipe.tokenizer(
        negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    ).input_ids.to(device)
    neg_tokens_g = pipe.tokenizer_2(
        negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    ).input_ids.to(device)

    with torch.no_grad():
        # CLIP-L: [0] = sequence embeddings, [1] = pooled output
        clip_l_output = pipe.text_encoder(tokens_l, output_hidden_states=False)
        clip_l_embeds = clip_l_output[0]
        neg_clip_l_output = pipe.text_encoder(neg_tokens_l, output_hidden_states=False)
        neg_clip_l_embeds = neg_clip_l_output[0]

        # CLIP-G: [0] = pooled output, [1] = sequence embeddings
        clip_g_output = pipe.text_encoder_2(tokens_g, output_hidden_states=False)
        clip_g_embeds = clip_g_output[1]  # sequence embeddings
        pooled_embeds = clip_g_output[0]  # pooled embeddings
        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g, output_hidden_states=False)
        neg_clip_g_embeds = neg_clip_g_output[1]
        neg_pooled_embeds = neg_clip_g_output[0]

    return {
        "clip_l": clip_l_embeds,
        "clip_g": clip_g_embeds,
        "neg_clip_l": neg_clip_l_embeds,
        "neg_clip_g": neg_clip_g_embeds,
        "pooled": pooled_embeds,
        "neg_pooled": neg_pooled_embeds,
    }
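
# For SDXL base, the returned shapes are: "clip_l"/"neg_clip_l" [1, 77, 768],
# "clip_g"/"neg_clip_g" [1, 77, 1280], and "pooled"/"neg_pooled" [1, 1280].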
# ─── Main Inference Function ──────────────────────────────────
@spaces.GPU  # assumed entry point for ZeroGPU; this Space runs on Zero hardware
def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, delta_scale,
          sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps, cfg_scale,
          scheduler_name, width, height, seed):
    global t5_tok, t5_mod, pipe

    device = torch.device("cuda")
    dtype = torch.float16

    # Lazily initialize models on the first call
    if t5_tok is None:
        t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
        t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
    if pipe is None:
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype,
            variant="fp16",
            use_safetensors=True
        ).to(device)
    # Set seed
    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)
        generator = torch.Generator(device=device).manual_seed(seed)
    else:
        generator = None

    # Set scheduler
    if scheduler_name in SCHEDULERS:
        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)

    # Get T5 embeddings
    t5_ids = t5_tok(
        prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True
    ).input_ids.to(device)
    with torch.no_grad():
        t5_seq = t5_mod(t5_ids).last_hidden_state

    # Get CLIP embeddings
    clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
    # Load and apply adapters; these booru checkpoints use 4 attention heads
    if adapter_l_file in (
        "t5-vit-l-14-dual_shunt_booru_13_000_000.safetensors",
        "t5-vit-l-14-dual_shunt_booru_51_200_000.safetensors",
    ):
        config_l["heads"] = 4
    else:
        config_l["heads"] = 12

    adapter_l = load_adapter(repo_l, adapter_l_file, config_l, device) if adapter_l_file else None
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g, device) if adapter_g_file else None

    # Apply CLIP-L adapter
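    # The update applied below, summarized in one line of math:
    #   gate  = sigmoid(gate_raw * g_pred * gpred_scale) * gate_prob
    #   clip' = clip + strength * gate * (delta * delta_scale)
    # followed by optional sigma noise, an optional anchor lerp, and optional
    # plain noise injection.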
    if adapter_l is not None:
        with torch.no_grad():
            # Run adapter forward pass
            adapter_output = adapter_l(t5_seq.float(), clip_embeds["clip_l"].float())

            # Unpack outputs (handle different return formats)
            if len(adapter_output) == 8:
                anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_output
            else:
                anchor_l = adapter_output[0]
                delta_l = adapter_output[1]
                log_sigma_l = adapter_output[2] if len(adapter_output) > 2 else torch.zeros_like(delta_l)
                gate_l = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_l)
                tau_l = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
                g_pred_l = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)

            # Scale deltas, fold the gate prediction into the gate, then squash
            delta_l = delta_l * delta_scale
            gate_l = gate_l * g_pred_l * gpred_scale
            gate_l_scaled = torch.sigmoid(gate_l) * gate_prob

            # Final delta with strength and gate applied
            delta_l_final = delta_l * strength * gate_l_scaled
            clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype)

            # Apply sigma-based noise if specified
            if sigma_scale > 0:
                sigma_l = torch.exp(log_sigma_l * sigma_scale)
                clip_l_mod += torch.randn_like(clip_l_mod) * sigma_l.to(dtype)

            # Apply anchor mixing if enabled
            if use_anchor:
                clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype)

            # Add additional noise if specified
            if noise > 0:
                clip_l_mod += torch.randn_like(clip_l_mod) * noise
    else:
        clip_l_mod = clip_embeds["clip_l"]
        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
        g_pred_l = torch.tensor(0.0)
        tau_l = torch.tensor(0.0)
    # Apply CLIP-G adapter
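    # Same gating math as the CLIP-L branch above, applied to the 1280-d
    # CLIP-G stream.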
    if adapter_g is not None:
        with torch.no_grad():
            # Run adapter forward pass
            adapter_output = adapter_g(t5_seq.float(), clip_embeds["clip_g"].float())

            # Unpack outputs (handle different return formats)
            if len(adapter_output) == 8:
                anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_output
            else:
                anchor_g = adapter_output[0]
                delta_g = adapter_output[1]
                log_sigma_g = adapter_output[2] if len(adapter_output) > 2 else torch.zeros_like(delta_g)
                gate_g = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_g)
                tau_g = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
                g_pred_g = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)

            # Scale deltas, fold the gate prediction into the gate, then squash
            delta_g = delta_g * delta_scale
            gate_g = gate_g * g_pred_g * gpred_scale
            gate_g_scaled = torch.sigmoid(gate_g) * gate_prob

            # Final delta with strength and gate applied
            delta_g_final = delta_g * strength * gate_g_scaled
            clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype)

            # Apply sigma-based noise if specified
            if sigma_scale > 0:
                sigma_g = torch.exp(log_sigma_g * sigma_scale)
                clip_g_mod += torch.randn_like(clip_g_mod) * sigma_g.to(dtype)

            # Apply anchor mixing if enabled
            if use_anchor:
                clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype)

            # Add additional noise if specified
            if noise > 0:
                clip_g_mod += torch.randn_like(clip_g_mod) * noise
    else:
        clip_g_mod = clip_embeds["clip_g"]
        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
        g_pred_g = torch.tensor(0.0)
        tau_g = torch.tensor(0.0)
    # Combine embeddings for SDXL: [CLIP-L (768) | CLIP-G (1280)] -> 2048
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1)
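
    # Sanity check (a sketch): SDXL's UNet expects 2048-d prompt embeddings.
    #   assert prompt_embeds.shape[-1] == 2048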
    # Generate image
    image = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=clip_embeds["pooled"],
        negative_prompt_embeds=neg_embeds,
        negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        num_images_per_prompt=1,
        generator=generator
    ).images[0]
    # Create visualizations
    delta_l_viz = plot_heat(delta_l_final.squeeze(), "CLIP-L Delta Values")
    gate_l_viz = plot_heat(gate_l_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-L Gate Activations")
    delta_g_viz = plot_heat(delta_g_final.squeeze(), "CLIP-G Delta Values")
    gate_g_viz = plot_heat(gate_g_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-G Gate Activations")
    # Statistics
    stats_l = f"g_pred_l: {float(g_pred_l.mean().item() if hasattr(g_pred_l, 'mean') else g_pred_l):.3f}, τ_l: {float(tau_l.mean().item() if hasattr(tau_l, 'mean') else tau_l):.3f}"
    stats_g = f"g_pred_g: {float(g_pred_g.mean().item() if hasattr(g_pred_g, 'mean') else g_pred_g):.3f}, τ_g: {float(tau_g.mean().item() if hasattr(tau_g, 'mean') else tau_g):.3f}"

    return image, delta_l_viz, gate_l_viz, delta_g_viz, gate_g_viz, stats_l, stats_g
# ─── Gradio Interface ─────────────────────────────────────────
def create_interface():
    with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual Shunt Adapter")
        gr.Markdown("*Enhance SDXL generation using T5 semantic understanding to modify CLIP embeddings*")

        with gr.Row():
            with gr.Column(scale=1):
                # Prompts
                gr.Markdown("### 📝 Prompts")
                prompt = gr.Textbox(
                    label="Prompt",
                    value="a futuristic control station with holographic displays",
                    lines=3,
                    placeholder="Describe what you want to generate..."
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, low quality, distorted",
                    lines=2,
                    placeholder="Describe what you want to avoid..."
                )

                # Adapters
                gr.Markdown("### ⚙️ Adapters")
                adapter_l = gr.Dropdown(
                    choices=["None"] + clip_l_opts,
                    label="CLIP-L (768d) Adapter",
                    value="t5-vit-l-14-dual_shunt_caption.safetensors",
                    info="Choose adapter for CLIP-L embeddings"
                )
                adapter_g = gr.Dropdown(
                    choices=["None"] + clip_g_opts,
                    label="CLIP-G (1280d) Adapter",
                    value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
                    info="Choose adapter for CLIP-G embeddings"
                )

                # Controls
                gr.Markdown("### 🎛️ Adapter Controls")
                strength = gr.Slider(0.0, 10.0, value=4.0, step=0.01, label="Adapter Strength")
                delta_scale = gr.Slider(-15.0, 15.0, value=0.2, step=0.1, label="Delta Scale", info="Scales the delta values, recommended 1")
                sigma_scale = gr.Slider(0.0, 15.0, value=0.1, step=0.1, label="Sigma Scale", info="Scales the noise variance, recommended 1")
                gpred_scale = gr.Slider(0.0, 20.0, value=2.0, step=0.01, label="G-Pred Scale", info="Scales the gate prediction, recommended 2")
                noise = gr.Slider(0.0, 1.0, value=0.55, step=0.01, label="Noise Injection")
                gate_prob = gr.Slider(0.0, 1.0, value=0.27, step=0.01, label="Gate Probability")
                use_anchor = gr.Checkbox(label="Use Anchor Points", value=True)

                # Generation settings
                gr.Markdown("### 🎨 Generation Settings")
                with gr.Row():
                    steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
                    cfg_scale = gr.Slider(1.0, 15.0, value=7.5, step=0.1, label="CFG Scale")
                scheduler_name = gr.Dropdown(
                    choices=list(SCHEDULERS.keys()),
                    value="DPM++ 2M",
                    label="Scheduler"
                )
                with gr.Row():
                    width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                seed = gr.Number(value=-1, label="Seed (-1 for random)", precision=0)

                generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg")

            with gr.Column(scale=1):
                # Output
                gr.Markdown("### 🖼️ Generated Image")
                output_image = gr.Image(label="Result", height=400, show_label=False)

                # Visualizations
                gr.Markdown("### 📊 Adapter Analysis")
                with gr.Row():
                    delta_l_img = gr.Image(label="CLIP-L Deltas", height=200)
                    gate_l_img = gr.Image(label="CLIP-L Gates", height=200)
                with gr.Row():
                    delta_g_img = gr.Image(label="CLIP-G Deltas", height=200)
                    gate_g_img = gr.Image(label="CLIP-G Gates", height=200)

                # Statistics
                gr.Markdown("### 📈 Statistics")
                stats_l_text = gr.Textbox(label="CLIP-L Metrics", interactive=False)
                stats_g_text = gr.Textbox(label="CLIP-G Metrics", interactive=False)

        # Event handler
        def run_generation(*args):
            # Map the "None" dropdown sentinel to a real None for both adapters
            processed_args = list(args)
            processed_args[2] = None if args[2] == "None" else args[2]  # adapter_l
            processed_args[3] = None if args[3] == "None" else args[3]  # adapter_g
            return infer(*processed_args)

        generate_btn.click(
            fn=run_generation,
            inputs=[
                prompt, negative_prompt, adapter_l, adapter_g, strength, delta_scale,
                sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps, cfg_scale,
                scheduler_name, width, height, seed
            ],
            outputs=[output_image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l_text, stats_g_text]
        )

    return demo
# ─── Launch ───────────────────────────────────────────────────
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()