AbstractPhil committed on
Commit
70f4ed5
·
1 Parent(s): f547ef2
Files changed (1) hide show
  1. app.py +29 -9
app.py CHANGED
@@ -143,13 +143,7 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
143
  pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
144
 
145
  # Get T5 embeddings for semantic understanding
146
- t5_ids = t5_tok(
147
- prompt,
148
- return_tensors="pt",
149
- padding="max_length",
150
- max_length=77, # Match CLIP's standard length
151
- truncation=True
152
- ).input_ids.to(device)
153
  t5_seq = t5_mod(t5_ids).last_hidden_state
154
 
155
  # Get proper SDXL CLIP embeddings
@@ -161,7 +155,20 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
161
 
162
  # Apply CLIP-L adapter
163
  if adapter_l is not None:
164
- anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  gate_l_scaled = gate_l * gate_prob
166
  delta_l_final = delta_l * strength * gate_l_scaled
167
  clip_l_mod = clip_embeds["clip_l"] + delta_l_final
@@ -178,7 +185,20 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
178
 
179
  # Apply CLIP-G adapter
180
  if adapter_g is not None:
181
- anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_embeds["clip_g"])
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  gate_g_scaled = gate_g * gate_prob
183
  delta_g_final = delta_g * strength * gate_g_scaled
184
  clip_g_mod = clip_embeds["clip_g"] + delta_g_final
 
143
  pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
144
 
145
  # Get T5 embeddings for semantic understanding
146
+ t5_ids = t5_tok(prompt, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
 
 
 
 
 
 
147
  t5_seq = t5_mod(t5_ids).last_hidden_state
148
 
149
  # Get proper SDXL CLIP embeddings
 
155
 
156
  # Apply CLIP-L adapter
157
  if adapter_l is not None:
158
+ # Ensure tensor shapes match for cross-attention
159
+ print(f"T5 seq shape: {t5_seq.shape}, CLIP-L shape: {clip_embeds['clip_l'].shape}")
160
+
161
+ # Resize T5 sequence to match CLIP sequence length if needed
162
+ if t5_seq.size(1) != clip_embeds["clip_l"].size(1):
163
+ t5_seq_resized = torch.nn.functional.interpolate(
164
+ t5_seq.transpose(1, 2),
165
+ size=clip_embeds["clip_l"].size(1),
166
+ mode="nearest"
167
+ ).transpose(1, 2)
168
+ else:
169
+ t5_seq_resized = t5_seq
170
+
171
+ anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq_resized, clip_embeds["clip_l"])
172
  gate_l_scaled = gate_l * gate_prob
173
  delta_l_final = delta_l * strength * gate_l_scaled
174
  clip_l_mod = clip_embeds["clip_l"] + delta_l_final
 
185
 
186
  # Apply CLIP-G adapter
187
  if adapter_g is not None:
188
+ # Ensure tensor shapes match for cross-attention
189
+ print(f"T5 seq shape: {t5_seq.shape}, CLIP-G shape: {clip_embeds['clip_g'].shape}")
190
+
191
+ # Resize T5 sequence to match CLIP sequence length if needed
192
+ if t5_seq.size(1) != clip_embeds["clip_g"].size(1):
193
+ t5_seq_resized = torch.nn.functional.interpolate(
194
+ t5_seq.transpose(1, 2),
195
+ size=clip_embeds["clip_g"].size(1),
196
+ mode="nearest"
197
+ ).transpose(1, 2)
198
+ else:
199
+ t5_seq_resized = t5_seq
200
+
201
+ anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq_resized, clip_embeds["clip_g"])
202
  gate_g_scaled = gate_g * gate_prob
203
  delta_g_final = delta_g * strength * gate_g_scaled
204
  clip_g_mod = clip_embeds["clip_g"] + delta_g_final