Dakerqi commited on
Commit
8ed7188
·
verified ·
1 Parent(s): 4e87cdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -76,12 +76,14 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
76
  )
77
 
78
  print(f"Text Encoder Device: {text_encoder.device}")
79
- text_input_ids = text_inputs.input_ids
80
- prompt_masks = text_inputs.attention_mask
 
 
81
 
82
  prompt_embeds = text_encoder(
83
- input_ids=text_input_ids.cuda(),
84
- attention_mask=prompt_masks.cuda(),
85
  output_hidden_states=True,
86
  ).hidden_states[-2]
87
 
 
76
  )
77
 
78
  print(f"Text Encoder Device: {text_encoder.device}")
79
+ text_input_ids = text_inputs.input_ids.cuda()
80
+ prompt_masks = text_inputs.attention_mask.cuda()
81
+ print(f"Text Input Ids Device: {text_input_ids.device}")
82
+ print(f"Prompt Masks Device: {prompt_masks.device}")
83
 
84
  prompt_embeds = text_encoder(
85
+ input_ids=text_input_ids,
86
+ attention_mask=prompt_masks,
87
  output_hidden_states=True,
88
  ).hidden_states[-2]
89