Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -76,12 +76,14 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
|
|
76 |
)
|
77 |
|
78 |
print(f"Text Encoder Device: {text_encoder.device}")
|
79 |
-
text_input_ids = text_inputs.input_ids
|
80 |
-
prompt_masks = text_inputs.attention_mask
|
|
|
|
|
81 |
|
82 |
prompt_embeds = text_encoder(
|
83 |
-
input_ids=text_input_ids
|
84 |
-
attention_mask=prompt_masks
|
85 |
output_hidden_states=True,
|
86 |
).hidden_states[-2]
|
87 |
|
|
|
76 |
)
|
77 |
|
78 |
print(f"Text Encoder Device: {text_encoder.device}")
|
79 |
+
text_input_ids = text_inputs.input_ids.cuda()
|
80 |
+
prompt_masks = text_inputs.attention_mask.cuda()
|
81 |
+
print(f"Text Input Ids Device: {text_input_ids.device}")
|
82 |
+
print(f"Prompt Masks Device: {prompt_masks.device}")
|
83 |
|
84 |
prompt_embeds = text_encoder(
|
85 |
+
input_ids=text_input_ids,
|
86 |
+
attention_mask=prompt_masks,
|
87 |
output_hidden_states=True,
|
88 |
).hidden_states[-2]
|
89 |
|