1inkusFace commited on
Commit
1c20b4c
·
verified ·
1 Parent(s): 4e598c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -210,7 +210,9 @@ def generate_30(
210
  )
211
  text_input_ids2 = text_inputs2.input_ids
212
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
 
213
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
 
214
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
215
 
216
  options = {
@@ -275,7 +277,9 @@ def generate_60(
275
  )
276
  text_input_ids2 = text_inputs2.input_ids
277
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
 
278
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
 
279
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
280
 
281
  options = {
@@ -340,7 +344,9 @@ def generate_90(
340
  )
341
  text_input_ids2 = text_inputs2.input_ids
342
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
 
343
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
 
344
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
345
 
346
  options = {
 
210
  )
211
  text_input_ids2 = text_inputs2.input_ids
212
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
213
+ prompt_embedsa = prompt_embedsa.hidden_states[-2]
214
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
215
+ prompt_embedsb = prompt_embedsb.hidden_states[-2]
216
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
217
 
218
  options = {
 
277
  )
278
  text_input_ids2 = text_inputs2.input_ids
279
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
280
+ prompt_embedsa = prompt_embedsa.hidden_states[-2]
281
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
282
+ prompt_embedsb = prompt_embedsb.hidden_states[-2]
283
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
284
 
285
  options = {
 
344
  )
345
  text_input_ids2 = text_inputs2.input_ids
346
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
347
+ prompt_embedsa = prompt_embedsa.hidden_states[-2]
348
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
349
+ prompt_embedsb = prompt_embedsb.hidden_states[-2]
350
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
351
 
352
  options = {