AbstractPhil committed
Commit d3c4f78 · 1 Parent(s): dfcfa0d
__pycache__/two_stream_shunt_adapter.cpython-310.pyc CHANGED
Binary files a/__pycache__/two_stream_shunt_adapter.cpython-310.pyc and b/__pycache__/two_stream_shunt_adapter.cpython-310.pyc differ
 
app.py CHANGED
@@ -131,123 +131,126 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
 
 # ─── Inference ────────────────────────────────────────────
 
-
-@torch.no_grad()
 @spaces.GPU
-def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
-          use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
-
-    # Initialize device and models inside GPU context
+def infer(
+    prompt, negative_prompt, adapter_l_file, adapter_g_file,
+    strength, noise, gate_prob, use_anchor,
+    steps, cfg_scale, scheduler_name,
+    width, height, seed
+):
+    import torch
+    import numpy as np
+
     global t5_tok, t5_mod, pipe
     device = torch.device("cuda")
     dtype = torch.float16
-
-    # Load models if not already loaded
-    if t5_tok is None:
-        t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
-        t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
-
-    if pipe is None:
-        pipe = StableDiffusionXLPipeline.from_pretrained(
-            "stabilityai/stable-diffusion-xl-base-1.0",
-            torch_dtype=dtype,
-            variant="fp16",
-            use_safetensors=True
-        ).to(device)
-
-    # Set seed for reproducibility
-    if seed != -1:
-        torch.manual_seed(seed)
-        np.random.seed(seed)
-
-    # Set scheduler
-    if scheduler_name in SCHEDULERS:
-        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
-
-    # Get T5 embeddings for semantic understanding - standardize to 77 tokens like CLIP
-    t5_ids = t5_tok(
-        prompt,
-        return_tensors="pt",
-        padding="max_length",
-        max_length=77,
-        truncation=True
-    ).input_ids.to(device)
-    t5_seq = t5_mod(t5_ids).last_hidden_state
-
-    # Get proper SDXL CLIP embeddings
-    clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)
-
-    # Debug shapes
-    print(f"T5 seq shape: {t5_seq.shape}")
-    print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
-    print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
-
-    # Load adapters
-    adapter_l = load_adapter(repo_l, adapter_l_file, config_l).to(device) if adapter_l_file else None
-    adapter_g = load_adapter(repo_g, adapter_g_file, config_g).to(device) if adapter_g_file else None
-
-    # Apply CLIP-L adapter
-    if adapter_l is not None:
-        anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
-        gate_l_scaled = gate_l * gate_prob
-        delta_l_final = delta_l * strength * gate_l_scaled
-        clip_l_mod = clip_embeds["clip_l"] + delta_l_final
-        if use_anchor:
-            clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
-        if noise > 0:
-            clip_l_mod += torch.randn_like(clip_l_mod) * noise
-    else:
-        clip_l_mod = clip_embeds["clip_l"]
-        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
-        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
-        g_pred_l = torch.tensor(0.0)
-        tau_l = torch.tensor(0.0)
-
-    # Apply CLIP-G adapter
-    if adapter_g is not None:
-        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
-        gate_g_scaled = gate_g * gate_prob
-        delta_g_final = delta_g * strength * gate_g_scaled
-        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
-        if use_anchor:
-            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
-        if noise > 0:
-            clip_g_mod += torch.randn_like(clip_g_mod) * noise
-    else:
-        clip_g_mod = clip_embeds["clip_g"]
-        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
-        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
-        g_pred_g = torch.tensor(0.0)
-        tau_g = torch.tensor(0.0)
-
-    # Combine embeddings in SDXL format: [CLIP-L(768) + CLIP-G(1280)] = 2048
-    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
-    neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)
-
-    # Generate image with proper SDXL parameters
-    image = pipe(
-        prompt_embeds=prompt_embeds,
-        pooled_prompt_embeds=clip_embeds["pooled"],
-        negative_prompt_embeds=neg_embeds,
-        negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
-        num_inference_steps=steps,
-        guidance_scale=cfg_scale,
-        width=width,
-        height=height,
-        num_images_per_prompt=1,  # Explicitly set this
-        generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
-    ).images[0]
-
+
+    with torch.no_grad():
+        # Initialize tokenizer and model
+        if t5_tok is None:
+            t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
+            t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
+
+        if pipe is None:
+            pipe = StableDiffusionXLPipeline.from_pretrained(
+                "stabilityai/stable-diffusion-xl-base-1.0",
+                torch_dtype=dtype,
+                variant="fp16",
+                use_safetensors=True
+            ).to(device)
+
+        # Reproducibility
+        if seed != -1:
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+
+        # Scheduler
+        if scheduler_name in SCHEDULERS:
+            pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
+
+        # T5 embeddings
+        t5_ids = t5_tok(
+            prompt, return_tensors="pt",
+            padding="max_length", max_length=77, truncation=True
+        ).input_ids.to(device)
+        t5_seq = t5_mod(t5_ids).last_hidden_state
+
+        # CLIP embeddings
+        clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)
+
+        # Debug shapes
+        print(f"T5 seq shape: {t5_seq.shape}")
+        print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
+        print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
+
+        # Load adapters
+        adapter_l = load_adapter(repo_l, adapter_l_file, config_l).to(device) if adapter_l_file else None
+        adapter_g = load_adapter(repo_g, adapter_g_file, config_g).to(device) if adapter_g_file else None
+
+        # ---- Adapter L ----
+        if adapter_l:
+            anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
+            gate_l_scaled = gate_l * gate_prob
+            delta_l_final = delta_l * strength * gate_l_scaled
+            clip_l_mod = clip_embeds["clip_l"] + delta_l_final
+            if use_anchor:
+                clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
+            if noise > 0:
+                clip_l_mod += torch.randn_like(clip_l_mod) * noise
+        else:
+            clip_l_mod = clip_embeds["clip_l"]
+            delta_l_final = torch.zeros_like(clip_l_mod)
+            gate_l_scaled = torch.zeros_like(clip_l_mod)
+            g_pred_l = torch.tensor(0.0)
+            tau_l = torch.tensor(0.0)
+
+        # ---- Adapter G ----
+        if adapter_g:
+            anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
+            gate_g_scaled = gate_g * gate_prob
+            delta_g_final = delta_g * strength * gate_g_scaled
+            clip_g_mod = clip_embeds["clip_g"] + delta_g_final
+            if use_anchor:
+                clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
+            if noise > 0:
+                clip_g_mod += torch.randn_like(clip_g_mod) * noise
+        else:
+            clip_g_mod = clip_embeds["clip_g"]
+            delta_g_final = torch.zeros_like(clip_g_mod)
+            gate_g_scaled = torch.zeros_like(clip_g_mod)
+            g_pred_g = torch.tensor(0.0)
+            tau_g = torch.tensor(0.0)
+
+        # ---- Combine embeddings ----
+        prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
+        neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1).to(dtype)
+
+        # ---- Generate image ----
+        generator = torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
+        image = pipe(
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=clip_embeds["pooled"],
+            negative_prompt_embeds=neg_embeds,
+            negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
+            num_inference_steps=steps,
+            guidance_scale=cfg_scale,
+            width=width,
+            height=height,
+            num_images_per_prompt=1,
+            generator=generator,
+        ).images[0]
+
     return (
         image,
         plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
-        plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
+        plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
         plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
         plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
         f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
         f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
     )
 
+
 # ─── Gradio Interface ─────────────────────────────────────────
 with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 🧠 SDXL Dual Shunt Adapter • T5→CLIP Enhancement")
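For reference, a minimal usage sketch of the refactored `infer` entry point as it could be called directly (outside the Gradio UI). It assumes app.py's module-level globals (`t5_tok`, `t5_mod`, `pipe`, `SCHEDULERS`, `repo_l`, `repo_g`) and a CUDA device, since the function hard-codes `torch.device("cuda")`; all argument values below are illustrative only.

```python
# Illustrative sketch: argument values are examples, and the scheduler name /
# adapter filenames must match entries that actually exist in SCHEDULERS / the repos.
outputs = infer(
    prompt="a lighthouse on a cliff at dusk, cinematic lighting",
    negative_prompt="blurry, low quality",
    adapter_l_file=None,   # or a CLIP-L adapter filename available in repo_l
    adapter_g_file=None,   # or a CLIP-G adapter filename available in repo_g
    strength=1.0,
    noise=0.0,
    gate_prob=1.0,
    use_anchor=True,
    steps=30,
    cfg_scale=7.5,
    scheduler_name="DPMSolverMultistep",  # unknown names fall back to the pipeline default
    width=1024,
    height=1024,
    seed=42,
)

# The function returns the image, four heatmap figures, and two stats strings.
image, delta_l_map, gate_l_map, delta_g_map, gate_g_map, stats_l, stats_g = outputs
image.save("output.png")
```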