5to9 commited on
Commit
d0aacc5
·
1 Parent(s): ee6ec78

0.14 set dtype of input_ids

Browse files
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -113,12 +113,14 @@ def generate_both(system_prompt, input_text, chatbot_a, chatbot_b, max_new_token
113
  input_ids_a = tokenizer_a.apply_chat_template(
114
  new_messages_a,
115
  add_generation_prompt=True,
 
116
  return_tensors="pt"
117
  ).to(model_a.device)
118
 
119
  input_ids_b = tokenizer_b.apply_chat_template(
120
  new_messages_b,
121
  add_generation_prompt=True,
 
122
  return_tensors="pt"
123
  ).to(model_b.device)
124
 
 
113
  input_ids_a = tokenizer_a.apply_chat_template(
114
  new_messages_a,
115
  add_generation_prompt=True,
116
+ dtype=torch.float16,
117
  return_tensors="pt"
118
  ).to(model_a.device)
119
 
120
  input_ids_b = tokenizer_b.apply_chat_template(
121
  new_messages_b,
122
  add_generation_prompt=True,
123
+ dtype=torch.float16,
124
  return_tensors="pt"
125
  ).to(model_b.device)
126