1inkusFace commited on
Commit
dc2becd
·
verified ·
1 Parent(s): f663c13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -219,9 +219,8 @@ def generate_30(
219
  pooled_prompt_embeds_list.append(prompt_embedsb[0])
220
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
221
  prompt_embeds_list.append(prompt_embedsb)
222
- prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
223
- prompt_embeds = prompt_embeds.repeat(1, 1, 1)
224
- pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
225
 
226
  options = {
227
  #"prompt": prompt,
@@ -295,9 +294,8 @@ def generate_60(
295
  pooled_prompt_embeds_list.append(prompt_embedsb[0])
296
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
297
  prompt_embeds_list.append(prompt_embedsb)
298
- prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
299
- prompt_embeds = prompt_embeds.repeat(1, 1, 1)
300
- pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
301
 
302
  options = {
303
  #"prompt": prompt,
@@ -371,9 +369,8 @@ def generate_90(
371
  pooled_prompt_embeds_list.append(prompt_embedsb[0])
372
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
373
  prompt_embeds_list.append(prompt_embedsb)
374
- prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
375
- prompt_embeds = prompt_embeds.repeat(1, 1, 1)
376
- pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
377
 
378
  options = {
379
  #"prompt": prompt,
 
219
  pooled_prompt_embeds_list.append(prompt_embedsb[0])
220
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
221
  prompt_embeds_list.append(prompt_embedsb)
222
+ prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
223
+ pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
 
224
 
225
  options = {
226
  #"prompt": prompt,
 
294
  pooled_prompt_embeds_list.append(prompt_embedsb[0])
295
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
296
  prompt_embeds_list.append(prompt_embedsb)
297
+ prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
298
+ pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
 
299
 
300
  options = {
301
  #"prompt": prompt,
 
369
  pooled_prompt_embeds_list.append(prompt_embedsb[0])
370
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
371
  prompt_embeds_list.append(prompt_embedsb)
372
+ prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
373
+ pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
 
374
 
375
  options = {
376
  #"prompt": prompt,