1inkusFace commited on
Commit
4e598c3
·
verified ·
1 Parent(s): 75cd14f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -209,8 +209,8 @@ def generate_30(
209
  return_tensors="pt",
210
  )
211
  text_input_ids2 = text_inputs2.input_ids
212
- prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device, dtype=torch.bfloat16), output_hidden_states=True)
213
- prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device, dtype=torch.bfloat16), output_hidden_states=True)
214
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
215
 
216
  options = {
@@ -274,8 +274,8 @@ def generate_60(
274
  return_tensors="pt",
275
  )
276
  text_input_ids2 = text_inputs2.input_ids
277
- prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device, dtype=torch.bfloat16), output_hidden_states=True)
278
- prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device, dtype=torch.bfloat16), output_hidden_states=True)
279
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
280
 
281
  options = {
@@ -339,8 +339,8 @@ def generate_90(
339
  return_tensors="pt",
340
  )
341
  text_input_ids2 = text_inputs2.input_ids
342
- prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device, dtype=torch.bfloat16), output_hidden_states=True)
343
- prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device, dtype=torch.bfloat16), output_hidden_states=True)
344
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
345
 
346
  options = {
 
209
  return_tensors="pt",
210
  )
211
  text_input_ids2 = text_inputs2.input_ids
212
+ prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
213
+ prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
214
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
215
 
216
  options = {
 
274
  return_tensors="pt",
275
  )
276
  text_input_ids2 = text_inputs2.input_ids
277
+ prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
278
+ prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
279
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
280
 
281
  options = {
 
339
  return_tensors="pt",
340
  )
341
  text_input_ids2 = text_inputs2.input_ids
342
+ prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
343
+ prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
344
  prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
345
 
346
  options = {