AbstractPhil committed on
Commit
c22af2e
·
verified ·
1 Parent(s): 504e98b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -112
app.py CHANGED
@@ -39,14 +39,13 @@ config_g = T5_SHUNT_REPOS["clip_g"]["config"]
39
  # ─── Loader ───────────────────────────────────────────────────
40
  from safetensors.torch import safe_open
41
 
42
- @spaces.GPU
43
  def load_adapter(repo, filename, config):
44
  # Don't initialize device here
45
  path = hf_hub_download(repo_id=repo, filename=filename)
46
 
47
  model = TwoStreamShuntAdapter(config).eval()
48
  tensors = {}
49
- with safe_open(path, framework="pt", device="cuda") as f:
50
  for key in f.keys():
51
  tensors[key] = f.get_tensor(key)
52
  model.load_state_dict(tensors)
@@ -129,131 +128,128 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
129
  "neg_pooled": neg_pooled_embeds
130
  }
131
 
 
 
 
 
132
 
133
  # ─── Inference ────────────────────────────────────────────
134
-
135
  @spaces.GPU
136
- def infer(
137
- prompt, negative_prompt, adapter_l_file, adapter_g_file,
138
- strength, noise, gate_prob, use_anchor,
139
- steps, cfg_scale, scheduler_name,
140
- width, height, seed
141
- ):
142
- import torch
143
- import numpy as np
144
-
145
  global t5_tok, t5_mod, pipe
146
  device = torch.device("cuda")
147
  dtype = torch.float16
148
-
149
- with torch.no_grad():
150
- # Initialize tokenizer and model
151
- if t5_tok is None:
152
- t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
153
- t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
154
-
155
- if pipe is None:
156
- pipe = StableDiffusionXLPipeline.from_pretrained(
157
- "stabilityai/stable-diffusion-xl-base-1.0",
158
- torch_dtype=dtype,
159
- variant="fp16",
160
- use_safetensors=True
161
- ).to(device)
162
- pipe.text_encoder = pipe.text_encoder.to(device)
163
- pipe.text_encoder_2 = pipe.text_encoder_2.to(device)
164
-
165
- # Reproducibility
166
- if seed != -1:
167
- torch.manual_seed(seed)
168
- np.random.seed(seed)
169
-
170
- # Scheduler
171
- if scheduler_name in SCHEDULERS:
172
- pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
173
-
174
- # T5 embeddings
175
- t5_ids = t5_tok(
176
- prompt, return_tensors="pt",
177
- padding="max_length", max_length=77, truncation=True
178
- ).input_ids.to(device)
179
- t5_seq = t5_mod(t5_ids).last_hidden_state
180
-
181
- # CLIP embeddings
182
- clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)
183
-
184
- # Debug shapes
185
- print(f"T5 seq shape: {t5_seq.shape}")
186
- print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
187
- print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
188
-
189
- # Load adapters
190
- adapter_l = load_adapter(repo_l, adapter_l_file, config_l).to(device) if adapter_l_file else None
191
- adapter_g = load_adapter(repo_g, adapter_g_file, config_g).to(device) if adapter_g_file else None
192
-
193
- # ---- Adapter L ----
194
- if adapter_l:
195
- anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
196
- gate_l_scaled = gate_l * gate_prob
197
- delta_l_final = delta_l * strength * gate_l_scaled
198
- clip_l_mod = clip_embeds["clip_l"] + delta_l_final
199
- if use_anchor:
200
- clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
201
- if noise > 0:
202
- clip_l_mod += torch.randn_like(clip_l_mod) * noise
203
- else:
204
- clip_l_mod = clip_embeds["clip_l"]
205
- delta_l_final = torch.zeros_like(clip_l_mod)
206
- gate_l_scaled = torch.zeros_like(clip_l_mod)
207
- g_pred_l = torch.tensor(0.0)
208
- tau_l = torch.tensor(0.0)
209
-
210
- # ---- Adapter G ----
211
- if adapter_g:
212
- anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
213
- gate_g_scaled = gate_g * gate_prob
214
- delta_g_final = delta_g * strength * gate_g_scaled
215
- clip_g_mod = clip_embeds["clip_g"] + delta_g_final
216
- if use_anchor:
217
- clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
218
- if noise > 0:
219
- clip_g_mod += torch.randn_like(clip_g_mod) * noise
220
- else:
221
- clip_g_mod = clip_embeds["clip_g"]
222
- delta_g_final = torch.zeros_like(clip_g_mod)
223
- gate_g_scaled = torch.zeros_like(clip_g_mod)
224
- g_pred_g = torch.tensor(0.0)
225
- tau_g = torch.tensor(0.0)
226
-
227
- # ---- Combine embeddings ----
228
- prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
229
- neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)
230
-
231
- # ---- Generate image ----
232
- generator = torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
233
- image = pipe(
234
- prompt_embeds=prompt_embeds,
235
- pooled_prompt_embeds=clip_embeds["pooled"],
236
- negative_prompt_embeds=neg_embeds,
237
- negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
238
- num_inference_steps=steps,
239
- guidance_scale=cfg_scale,
240
- width=width,
241
- height=height,
242
- num_images_per_prompt=1,
243
- generator=generator,
244
- ).images[0]
245
-
246
  return (
247
  image,
248
  plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
249
- plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
250
  plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
251
  plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
252
  f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
253
  f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
254
  )
255
 
256
-
257
  # ─── Gradio Interface ─────────────────────────────────────────
258
  with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
259
  gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement")
 
39
  # ─── Loader ───────────────────────────────────────────────────
40
  from safetensors.torch import safe_open
41
 
 
42
  def load_adapter(repo, filename, config):
43
  # Don't initialize device here
44
  path = hf_hub_download(repo_id=repo, filename=filename)
45
 
46
  model = TwoStreamShuntAdapter(config).eval()
47
  tensors = {}
48
+ with safe_open(path, framework="pt", device="cpu") as f:
49
  for key in f.keys():
50
  tensors[key] = f.get_tensor(key)
51
  model.load_state_dict(tensors)
 
128
  "neg_pooled": neg_pooled_embeds
129
  }
130
 
131
+ # ─── Inference ────────────────────────────────────────────────
132
+ @torch.no_grad()
133
+ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
134
+ use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
135
 
136
  # ─── Inference ────────────────────────────────────────────
 
137
  @spaces.GPU
138
+ @torch.no_grad()
139
+ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
140
+ use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
141
+
142
+ # Initialize device and models inside GPU context
 
 
 
 
143
  global t5_tok, t5_mod, pipe
144
  device = torch.device("cuda")
145
  dtype = torch.float16
146
+
147
+ # Load models if not already loaded
148
+ if t5_tok is None:
149
+ t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
150
+ t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
151
+
152
+ if pipe is None:
153
+ pipe = StableDiffusionXLPipeline.from_pretrained(
154
+ "stabilityai/stable-diffusion-xl-base-1.0",
155
+ torch_dtype=dtype,
156
+ variant="fp16",
157
+ use_safetensors=True
158
+ ).to(device)
159
+
160
+ # Set seed for reproducibility
161
+ if seed != -1:
162
+ torch.manual_seed(seed)
163
+ np.random.seed(seed)
164
+
165
+ # Set scheduler
166
+ if scheduler_name in SCHEDULERS:
167
+ pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
168
+
169
+ # Get T5 embeddings for semantic understanding - standardize to 77 tokens like CLIP
170
+ t5_ids = t5_tok(
171
+ prompt,
172
+ return_tensors="pt",
173
+ padding="max_length",
174
+ max_length=77,
175
+ truncation=True
176
+ ).input_ids.to(device)
177
+ t5_seq = t5_mod(t5_ids).last_hidden_state
178
+
179
+ # Get proper SDXL CLIP embeddings
180
+ clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
181
+
182
+ # Debug shapes
183
+ print(f"T5 seq shape: {t5_seq.shape}")
184
+ print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
185
+ print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
186
+
187
+ # Load adapters
188
+ adapter_l = load_adapter(repo_l, adapter_l_file, config_l).to(device) if adapter_l_file else None
189
+ adapter_g = load_adapter(repo_g, adapter_g_file, config_g).to(device) if adapter_g_file else None
190
+
191
+ # Apply CLIP-L adapter
192
+ if adapter_l is not None:
193
+ anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
194
+ gate_l_scaled = gate_l * gate_prob
195
+ delta_l_final = delta_l * strength * gate_l_scaled
196
+ clip_l_mod = clip_embeds["clip_l"] + delta_l_final
197
+ if use_anchor:
198
+ clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
199
+ if noise > 0:
200
+ clip_l_mod += torch.randn_like(clip_l_mod) * noise
201
+ else:
202
+ clip_l_mod = clip_embeds["clip_l"]
203
+ delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
204
+ gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
205
+ g_pred_l = torch.tensor(0.0)
206
+ tau_l = torch.tensor(0.0)
207
+
208
+ # Apply CLIP-G adapter
209
+ if adapter_g is not None:
210
+ anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
211
+ gate_g_scaled = gate_g * gate_prob
212
+ delta_g_final = delta_g * strength * gate_g_scaled
213
+ clip_g_mod = clip_embeds["clip_g"] + delta_g_final
214
+ if use_anchor:
215
+ clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
216
+ if noise > 0:
217
+ clip_g_mod += torch.randn_like(clip_g_mod) * noise
218
+ else:
219
+ clip_g_mod = clip_embeds["clip_g"]
220
+ delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
221
+ gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
222
+ g_pred_g = torch.tensor(0.0)
223
+ tau_g = torch.tensor(0.0)
224
+
225
+ # Combine embeddings in SDXL format: [CLIP-L(768) + CLIP-G(1280)] = 2048
226
+ prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
227
+ neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)
228
+
229
+ # Generate image with proper SDXL parameters
230
+ image = pipe(
231
+ prompt_embeds=prompt_embeds,
232
+ pooled_prompt_embeds=clip_embeds["pooled"],
233
+ negative_prompt_embeds=neg_embeds,
234
+ negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
235
+ num_inference_steps=steps,
236
+ guidance_scale=cfg_scale,
237
+ width=width,
238
+ height=height,
239
+ num_images_per_prompt=1, # Explicitly set this
240
+ generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
241
+ ).images[0]
242
+
 
243
  return (
244
  image,
245
  plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
246
+ plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
247
  plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
248
  plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
249
  f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
250
  f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
251
  )
252
 
 
253
  # ─── Gradio Interface ─────────────────────────────────────────
254
  with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
255
  gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement")