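"""SDXL Dual Shunt Adapter demo.

Uses T5 semantic embeddings to steer SDXL's CLIP-L and CLIP-G prompt
embeddings through trained two-stream shunt adapters, and visualizes the
per-token deltas and gates in a Gradio UI.
"""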
import spaces
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from safetensors.torch import safe_open
from huggingface_hub import hf_hub_download
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS
# ─── Device & Model Setup ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# T5 Model for semantic understanding
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
# SDXL Pipeline with proper text encoders
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
    use_safetensors=True,
).to(device)
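# The pipeline's own tokenizers and text encoders are reused below, so the
# adapter inputs match what SDXL actually conditions on.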
# Available schedulers
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}
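# A scheduler is swapped in per request via .from_config(pipe.scheduler.config)
# inside infer(), preserving the pipeline's scheduler configuration.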
# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]
# ─── Loader ───────────────────────────────────────────────────
def load_adapter(repo, filename, config):
    """Download a shunt adapter checkpoint from the Hub and load its weights."""
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    # Keep the adapter in fp32 to match the fp32 T5 hidden states it consumes
    model.to(device).float()
    return model
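# Example usage with the defaults above, e.g.:
#   adapter = load_adapter(repo_l, clip_l_opts[0], config_l)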
# ─── Visualization ────────────────────────────────────────────
def plot_heat(mat, title):
    """Render a 2-D array as a blue-white-red heatmap and return a PIL image."""
    import io
    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
    im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
    ax.set_title(title)
    plt.colorbar(im, ax=ax)
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    plt.close(fig)
    # gr.Image expects a filepath, numpy array, or PIL image, not a raw buffer
    img = Image.open(buf)
    img.load()  # fully read before the buffer goes out of scope
    return img
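# Used below to visualize the per-token delta and gate maps from each adapter.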
# ─── SDXL Text Encoding ───────────────────────────────────────
def encode_sdxl_prompt(prompt, negative_prompt=""):
    """Generate CLIP-L and CLIP-G embeddings using SDXL's own text encoders."""

    def tokenize(tokenizer, text):
        # Both encoders take fixed 77-token sequences, as in CLIP training
        return tokenizer(
            text,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(device)

    tokens_l = tokenize(pipe.tokenizer, prompt)
    tokens_g = tokenize(pipe.tokenizer_2, prompt)
    neg_tokens_l = tokenize(pipe.tokenizer, negative_prompt)
    neg_tokens_g = tokenize(pipe.tokenizer_2, negative_prompt)

    with torch.no_grad():
        # CLIP-L (768d): output [0] is the sequence embedding
        clip_l_embeds = pipe.text_encoder(tokens_l)[0]
        neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]

        # CLIP-G (1280d): output [0] is pooled, [1] is the sequence
        # (the opposite ordering of CLIP-L)
        clip_g_output = pipe.text_encoder_2(tokens_g)
        clip_g_embeds = clip_g_output[1]
        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
        neg_clip_g_embeds = neg_clip_g_output[1]

        # Pooled CLIP-G embeddings feed SDXL's added conditioning
        pooled_embeds = clip_g_output[0]
        neg_pooled_embeds = neg_clip_g_output[0]

    return {
        "clip_l": clip_l_embeds,
        "clip_g": clip_g_embeds,
        "neg_clip_l": neg_clip_l_embeds,
        "neg_clip_g": neg_clip_g_embeds,
        "pooled": pooled_embeds,
        "neg_pooled": neg_pooled_embeds,
    }
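# Each sequence embedding has shape (1, 77, dim): 768 for CLIP-L and 1280 for
# CLIP-G; the pooled embedding is (1, 1280).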
# ─── Inference ────────────────────────────────────────────────
@spaces.GPU
@torch.no_grad()
def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
          use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
    # Seed for reproducibility; gr.Number may deliver a float, so cast first
    seed = int(seed)
    if seed != -1:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Swap in the requested scheduler, preserving the current config
    if scheduler_name in SCHEDULERS:
        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
    # T5 embeddings for semantic conditioning, padded to 77 tokens to match CLIP
    t5_ids = t5_tok(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        truncation=True,
    ).input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state

    # Proper SDXL CLIP embeddings for both encoders
    clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)

    # Debug shapes
    print(f"T5 seq shape: {t5_seq.shape}")
    print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
    print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
    # Load the selected adapters (None leaves that stream untouched)
    adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
    adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None

    # Apply the CLIP-L adapter: scale the predicted delta by the gate, add it
    # to the CLIP-L sequence, and optionally blend toward the anchor
    if adapter_l is not None:
        # The adapter runs in fp32 (see load_adapter), so cast the fp16 CLIP
        # embeddings up; everything is cast back to the pipe dtype below
        anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(
            t5_seq, clip_embeds["clip_l"].float()
        )
        gate_l_scaled = gate_l * gate_prob
        delta_l_final = delta_l * strength * gate_l_scaled
        clip_l_mod = clip_embeds["clip_l"] + delta_l_final
        if use_anchor:
            clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
        if noise > 0:
            clip_l_mod += torch.randn_like(clip_l_mod) * noise
    else:
        clip_l_mod = clip_embeds["clip_l"]
        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
        g_pred_l = torch.tensor(0.0)
        tau_l = torch.tensor(0.0)
    # Apply the CLIP-G adapter (same scheme as CLIP-L, in 1280d)
    if adapter_g is not None:
        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(
            t5_seq, clip_embeds["clip_g"].float()
        )
        gate_g_scaled = gate_g * gate_prob
        delta_g_final = delta_g * strength * gate_g_scaled
        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
        if use_anchor:
            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
        if noise > 0:
            clip_g_mod += torch.randn_like(clip_g_mod) * noise
    else:
        clip_g_mod = clip_embeds["clip_g"]
        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
        g_pred_g = torch.tensor(0.0)
        tau_g = torch.tensor(0.0)
    # Combine into SDXL's joint format: [CLIP-L(768) | CLIP-G(1280)] = 2048d
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
    neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)

    # Generate; width/height/steps arrive from sliders, so cast to int for the pipeline
    image = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=clip_embeds["pooled"],
        negative_prompt_embeds=neg_embeds,
        negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
        num_inference_steps=int(steps),
        guidance_scale=cfg_scale,
        width=int(width),
        height=int(height),
        num_images_per_prompt=1,
        generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None,
    ).images[0]
    return (
        image,
        plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
        plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
        plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
        plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}",
    )
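# The 7-tuple above maps one-to-one onto the Gradio outputs wired up below:
# the image, four heatmaps, and two stat strings.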
# ─── Gradio Interface ─────────────────────────────────────────
with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement")
    gr.Markdown("Enhance SDXL generation by using T5 semantic understanding to modify CLIP embeddings")

    with gr.Row():
        with gr.Column(scale=1):
            # Prompts
            with gr.Group():
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(
                    label="Prompt",
                    value="a futuristic control station with holographic displays",
                    lines=3,
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, low quality, distorted",
                    lines=2,
                )

            # Adapters
            with gr.Group():
                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(
                    choices=["None"] + clip_l_opts,
                    label="CLIP-L (768d) Adapter",
                    value="None",
                )
                adapter_g = gr.Dropdown(
                    choices=["None"] + clip_g_opts,
                    label="CLIP-G (1280d) Adapter",
                    value="None",
                )

            # Adapter Controls
            with gr.Group():
                gr.Markdown("### Adapter Controls")
                strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
                noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
                gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
                use_anchor = gr.Checkbox(label="Use Anchor", value=True)

            # Generation Settings
            with gr.Group():
                gr.Markdown("### Generation Settings")
                with gr.Row():
                    steps = gr.Slider(1, 100, value=25, step=1, label="Steps")
                    cfg_scale = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="CFG Scale")
                scheduler_name = gr.Dropdown(
                    choices=list(SCHEDULERS.keys()),
                    value="DPM++ 2M",
                    label="Scheduler",
                )
                with gr.Row():
                    width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                seed = gr.Number(value=-1, label="Seed (-1 for random)")

            run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output
            with gr.Group():
                gr.Markdown("### Generated Image")
                out_img = gr.Image(label="Result", height=400)

            # Visualizations
            with gr.Group():
                gr.Markdown("### Adapter Visualizations")
                with gr.Row():
                    delta_l = gr.Image(label="Δ CLIP-L", height=200)
                    gate_l = gr.Image(label="Gate CLIP-L", height=200)
                with gr.Row():
                    delta_g = gr.Image(label="Δ CLIP-G", height=200)
                    gate_g = gr.Image(label="Gate CLIP-G", height=200)

            # Stats
            with gr.Group():
                gr.Markdown("### Adapter Statistics")
                stats_l = gr.Textbox(label="CLIP-L Stats", interactive=False)
                stats_g = gr.Textbox(label="CLIP-G Stats", interactive=False)
    # Event handlers
    def process_adapters(adapter_l_val, adapter_g_val):
        # Map the "None" dropdown sentinel back to a real None
        adapter_l_processed = None if adapter_l_val == "None" else adapter_l_val
        adapter_g_processed = None if adapter_g_val == "None" else adapter_g_val
        return adapter_l_processed, adapter_g_processed

    def run_inference(*args):
        # Replace the raw dropdown values (positions 2 and 3) before dispatching
        adapter_l_processed, adapter_g_processed = process_adapters(args[2], args[3])
        new_args = list(args)
        new_args[2] = adapter_l_processed
        new_args[3] = adapter_g_processed
        return infer(*new_args)
    run_btn.click(
        fn=run_inference,
        inputs=[
            prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob,
            use_anchor, steps, cfg_scale, scheduler_name, width, height, seed,
        ],
        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g],
    )
if __name__ == "__main__":
    demo.launch()