AbstractPhil committed
Commit cae6d82 · Parent: acd9841
Files changed (1):
  1. app.py +32 -9

app.py CHANGED
@@ -111,7 +111,7 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
     clip_l_embeds = pipe.text_encoder(tokens_l)[0]
     neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]

-    # CLIP-G embeddings (1280d)
+    # CLIP-G embeddings (1280d) - get the hidden states [0], not pooled [1]
     clip_g_embeds = pipe.text_encoder_2(tokens_g)[0]
     neg_clip_g_embeds = pipe.text_encoder_2(neg_tokens_g)[0]
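For reference, the two SDXL text encoders expose their outputs differently in diffusers, which is what the updated comment is getting at. A minimal sketch, not part of the commit: `text_encoder` is a CLIPTextModel whose first output is the hidden states, while `text_encoder_2` is a CLIPTextModelWithProjection whose first output is the pooled `text_embeds`, so requesting hidden states explicitly is the safer pattern.

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

tokens_l = pipe.tokenizer("a cat", padding="max_length", max_length=77,
                          truncation=True, return_tensors="pt").input_ids
tokens_g = pipe.tokenizer_2("a cat", padding="max_length", max_length=77,
                            truncation=True, return_tensors="pt").input_ids

# CLIP-L (CLIPTextModel): output [0] is last_hidden_state, shape (1, 77, 768)
hidden_l = pipe.text_encoder(tokens_l)[0]

# CLIP-G (CLIPTextModelWithProjection): output [0] is the pooled text_embeds,
# shape (1, 1280), so ask for the hidden states explicitly
out_g = pipe.text_encoder_2(tokens_g, output_hidden_states=True)
pooled_g = out_g[0]                 # (1, 1280) projected pooled embedding
hidden_g = out_g.hidden_states[-2]  # (1, 77, 1280) penultimate layer, as SDXL uses
```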
@@ -143,14 +143,7 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)

     # Get T5 embeddings for semantic understanding
-    t5_ids = t5_tok(
-        prompt,
-        return_tensors="pt",
-        padding="max_length",
-        max_length=77,  # Match CLIP's standard length
-        truncation=True
-    ).input_ids.to(device)
-    print(t5_ids.shape)
+    t5_ids = t5_tok(prompt, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
     t5_seq = t5_mod(t5_ids).last_hidden_state

     # Get proper SDXL CLIP embeddings
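The removed call padded every prompt to a fixed 77 tokens; its replacement pads only to the longest prompt in the batch, so the T5 sequence length now varies with the input. A quick sketch of the difference, assuming a generic T5 checkpoint rather than the one the Space actually loads:

```python
from transformers import AutoTokenizer

t5_tok = AutoTokenizer.from_pretrained("google/flan-t5-base")  # stand-in checkpoint

fixed = t5_tok("a cat", return_tensors="pt", padding="max_length",
               max_length=77, truncation=True).input_ids
dynamic = t5_tok("a cat", return_tensors="pt", padding=True,
                 truncation=True).input_ids

print(fixed.shape)    # torch.Size([1, 77]) - always padded to CLIP's length
print(dynamic.shape)  # e.g. torch.Size([1, 3]) - only as long as the prompt
# The variable length is why the commit adds an interpolation step below.
```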
@@ -160,6 +153,19 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
     adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None

+    # Ensure all embeddings have the same sequence length (77 tokens)
+    seq_len = 77
+
+    # Resize T5 to match CLIP sequence length
+    if t5_seq.size(1) != seq_len:
+        t5_seq = torch.nn.functional.interpolate(
+            t5_seq.transpose(1, 2),
+            size=seq_len,
+            mode="nearest"
+        ).transpose(1, 2)
+
+    print(f"After resize - T5: {t5_seq.shape}, CLIP-L: {clip_embeds['clip_l'].shape}, CLIP-G: {clip_embeds['clip_g'].shape}")
+
     # Apply CLIP-L adapter
     if adapter_l is not None:
         anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
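The resize works because torch.nn.functional.interpolate treats a 3-D tensor as (batch, channels, length), so the token axis has to be swapped into the last position, stretched to 77, and swapped back. A self-contained sketch of just that step:

```python
import torch
import torch.nn.functional as F

t5_seq = torch.randn(1, 12, 768)    # (batch, tokens, dim): 12 T5 tokens
resized = F.interpolate(
    t5_seq.transpose(1, 2),         # -> (1, 768, 12): dims become "channels"
    size=77,                        # stretch the token axis to CLIP's 77
    mode="nearest",                 # repeat nearest tokens rather than blending
).transpose(1, 2)                   # -> (1, 77, 768)

print(resized.shape)                # torch.Size([1, 77, 768])
```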
@@ -187,6 +193,23 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
             clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
         if noise > 0:
             clip_g_mod += torch.randn_like(clip_g_mod) * noise
+    else:
+        clip_g_mod = clip_embeds["clip_g"]
+        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
+        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
+        g_pred_g = torch.tensor(0.0)
+        tau_g = torch.tensor(0.0)
+
+    # Apply CLIP-G adapter
+    if adapter_g is not None:
+        # Resize T5 for CLIP-G if needed
+        if t5_seq.size(1) != seq_len:
+            t5_seq_resized = torch.nn.functional.interpolate(
+                t5_seq.transpose(1, 2), size=seq_len, mode="nearest"
+            ).transpose(1, 2)
+        else:
+            t5_seq_resized = t5_seq
+
+        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq_resized, clip_embeds["clip_g"])
+        gate_g_scaled = gate_g * gate_prob
+        delta_g_final = delta_g * strength * gate_g_scaled
+        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
+        if use_anchor:
+            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
+        if noise > 0:
+            clip_g_mod += torch.randn_like(clip_g_mod) * noise
     else:
         clip_g_mod = clip_embeds["clip_g"]
         delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
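Both adapter branches reduce to the same gated residual update: scale the predicted delta by strength and a gate, add it to the base CLIP embedding, then optionally lerp toward the adapter's anchor where the gate is open. A distilled sketch; the function name and shapes are illustrative, not from app.py:

```python
import torch

def apply_adapter(clip_embeds, delta, gate, anchor,
                  strength=1.0, gate_prob=1.0, use_anchor=True, noise=0.0):
    gate_scaled = gate * gate_prob                 # per-token gate in [0, 1]
    delta_final = delta * strength * gate_scaled   # scaled, gated residual
    out = clip_embeds + delta_final
    if use_anchor:
        # lerp toward the anchor wherever the gate is open
        out = out * (1 - gate_scaled) + anchor * gate_scaled
    if noise > 0:
        out = out + torch.randn_like(out) * noise  # optional exploration noise
    return out

x = torch.randn(1, 77, 1280)                       # CLIP-G-shaped embeddings
y = apply_adapter(x, delta=torch.randn_like(x), gate=torch.rand(1, 77, 1),
                  anchor=torch.randn_like(x), strength=0.5)
print(y.shape)                                     # torch.Size([1, 77, 1280])
```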
 