1inkusFace commited on
Commit
91ac098
·
verified ·
1 Parent(s): 152d8fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -226,6 +226,7 @@ def generate_30(
226
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
227
 
228
  # 4. (Optional) Average the pooled embeddings
 
229
  pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
230
 
231
 
@@ -308,6 +309,7 @@ def generate_60(
308
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
309
 
310
  # 4. (Optional) Average the pooled embeddings
 
311
  pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
312
 
313
 
@@ -390,6 +392,7 @@ def generate_90(
390
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
391
 
392
  # 4. (Optional) Average the pooled embeddings
 
393
  pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
394
 
395
 
 
226
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
227
 
228
  # 4. (Optional) Average the pooled embeddings
229
+ prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
230
  pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
231
 
232
 
 
309
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
310
 
311
  # 4. (Optional) Average the pooled embeddings
312
+ prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
313
  pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
314
 
315
 
 
392
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
393
 
394
  # 4. (Optional) Average the pooled embeddings
395
+ prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
396
  pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
397
 
398