1inkusFace commited on
Commit
f2843e0
·
verified ·
1 Parent(s): bc6ba86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -267,11 +267,11 @@ def generate_30(
267
  print('averaged shape: ', prompt_embeds.shape)
268
 
269
  # 3. Concatenate the text_encoder_2 embeddings
270
- prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
271
  print('catted shape2: ', prompt_embeds2.shape)
272
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
273
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
274
- pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
275
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
276
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
277
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
 
267
  print('averaged shape: ', prompt_embeds.shape)
268
 
269
  # 3. Concatenate the text_encoder_2 embeddings
270
+ prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
271
  print('catted shape2: ', prompt_embeds2.shape)
272
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
273
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
274
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=1)
275
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
276
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
277
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)