OscarFAI committed on
Commit
804e8e0
·
1 Parent(s): de6224b

Added pad token and attention mask

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -77,17 +77,20 @@ def chat_llama3_8b(message: str,
77
  conversation.append({"role": "user", "content": message})
78
 
79
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
80
-
 
81
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
82
 
83
  generate_kwargs = dict(
84
  input_ids= input_ids,
 
85
  streamer=streamer,
86
  max_new_tokens=max_new_tokens,
87
  do_sample=True,
88
  temperature=temperature,
89
  top_p=top_p,
90
  eos_token_id=terminators,
 
91
  )
92
  # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
93
  if temperature == 0:
 
77
  conversation.append({"role": "user", "content": message})
78
 
79
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
80
+ attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
81
+
82
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
83
 
84
  generate_kwargs = dict(
85
  input_ids= input_ids,
86
+ attention_mask=attention_mask,
87
  streamer=streamer,
88
  max_new_tokens=max_new_tokens,
89
  do_sample=True,
90
  temperature=temperature,
91
  top_p=top_p,
92
  eos_token_id=terminators,
93
+ pad_token_id=tokenizer.eos_token_id,
94
  )
95
  # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
96
  if temperature == 0: