1inkusFace commited on
Commit
9f1950c
·
verified ·
1 Parent(s): 7b87899

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -271,7 +271,7 @@ def generate_30(
271
  print('catted shape2: ', prompt_embeds2.shape)
272
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
273
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
274
- pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
275
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
276
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
277
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
 
271
  print('catted shape2: ', prompt_embeds2.shape)
272
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
273
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
274
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0,keepdim=True)
275
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
276
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
277
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)