Added pad token and attention mask
Browse files
app.py
CHANGED
@@ -77,17 +77,20 @@ def chat_llama3_8b(message: str,
     conversation.append({"role": "user", "content": message})

     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-
+    attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
+
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

     generate_kwargs = dict(
         input_ids= input_ids,
+        attention_mask=attention_mask,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
         top_p=top_p,
         eos_token_id=terminators,
+        pad_token_id=tokenizer.eos_token_id,
     )
     # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
     if temperature == 0: