AbstractPhil committed on
Commit
db851e8
·
1 Parent(s): b012ee8
Files changed (1)
  1. app.py +27 -6
app.py CHANGED
@@ -107,17 +107,38 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
     ).input_ids.to(device)
 
     with torch.no_grad():
-        # CLIP-L embeddings (768d) - [0] is sequence, [1] is pooled
+        # CLIP-L embeddings (768d) - works fine
         clip_l_embeds = pipe.text_encoder(tokens_l)[0]
         neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]
 
-        # CLIP-G embeddings (1280d) - [0] is sequence, [1] is pooled (reuse same prompt)
-        clip_g_embeds = pipe.text_encoder_2(tokens_g)[0]
-        neg_clip_g_embeds = pipe.text_encoder_2(neg_tokens_g)[0]
+        # CLIP-G embeddings (1280d) - debug the output structure
+        clip_g_output = pipe.text_encoder_2(tokens_g)
+        print(f"CLIP-G output type: {type(clip_g_output)}")
+        print(f"CLIP-G output length: {len(clip_g_output) if hasattr(clip_g_output, '__len__') else 'no len'}")
+        if hasattr(clip_g_output, '__len__') and len(clip_g_output) > 0:
+            print(f"CLIP-G [0] shape: {clip_g_output[0].shape}")
+            if len(clip_g_output) > 1:
+                print(f"CLIP-G [1] shape: {clip_g_output[1].shape}")
+
+        # Try different ways to get the sequence embeddings
+        if hasattr(clip_g_output, 'last_hidden_state'):
+            clip_g_embeds = clip_g_output.last_hidden_state
+        elif hasattr(clip_g_output, '__len__') and len(clip_g_output) > 0:
+            clip_g_embeds = clip_g_output[0]
+        else:
+            clip_g_embeds = clip_g_output
+
+        neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
+        if hasattr(neg_clip_g_output, 'last_hidden_state'):
+            neg_clip_g_embeds = neg_clip_g_output.last_hidden_state
+        elif hasattr(neg_clip_g_output, '__len__') and len(neg_clip_g_output) > 0:
+            neg_clip_g_embeds = neg_clip_g_output[0]
+        else:
+            neg_clip_g_embeds = neg_clip_g_output
 
         # Pooled embeddings for SDXL
-        pooled_embeds = pipe.text_encoder_2(tokens_g)[1]
-        neg_pooled_embeds = pipe.text_encoder_2(neg_tokens_g)[1]
+        pooled_embeds = clip_g_output[1] if hasattr(clip_g_output, '__len__') and len(clip_g_output) > 1 else clip_g_output.pooler_output
+        neg_pooled_embeds = neg_clip_g_output[1] if hasattr(neg_clip_g_output, '__len__') and len(neg_clip_g_output) > 1 else neg_clip_g_output.pooler_output
 
     return {
         "clip_l": clip_l_embeds,