Padding
app.py CHANGED
@@ -3,6 +3,7 @@ import os
 import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
+import torch
 
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -43,6 +44,10 @@ h1 {
 tokenizer = AutoTokenizer.from_pretrained("mistralai/Ministral-8B-Instruct-2410")
 model = AutoModelForCausalLM.from_pretrained("mistralai/Ministral-8B-Instruct-2410", device_map="auto")
 
+# Ensure we have a pad token
+if tokenizer.pad_token_id is None:
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
 terminators = [
     tokenizer.eos_token_id,
     tokenizer.convert_tokens_to_ids("<|eot_id|>")
@@ -88,17 +93,23 @@ def chat_mistral(message: str,
     first_message = f"{formatted_prompt}{message}" if formatted_prompt else message
     conversation.append({"role": "user", "content": first_message})
 
-
-
+    # Tokenize with padding and attention mask
+    input_data = tokenizer.apply_chat_template(conversation, return_tensors="pt", padding=True, truncation=True)
+    input_ids = input_data.to(model.device)
+
+    attention_mask = input_ids.ne(tokenizer.pad_token_id).to(dtype=torch.long, device=model.device)
+
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
         input_ids=input_ids,
+        attention_mask=attention_mask,  # Fixes the warning
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
         top_p=top_p,
+        pad_token_id=tokenizer.pad_token_id,  # Explicitly set
         eos_token_id=terminators,
     )
 
@@ -139,10 +150,13 @@ with gr.Blocks(fill_height=True, css=css) as demo:
         gr.Slider(minimum=128, maximum=4096, step=1, value=4096, label="Max new tokens", render=False),
     ],
     examples=[
-        ['
+        ['How to setup a human base on Mars? Give short answer.'],
+        ['Explain theory of relativity to me like I’m 8 years old.'],
+        ['What is 9,000 * 9,000?'],
+        ['Write a pun-filled happy birthday message to my friend Alex.'],
+        ['Justify why a penguin might make a good king of the jungle.']
     ],
-    cache_examples=False
-    type='messages',
+    cache_examples=False
     )
 
 if __name__ == "__main__":
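For context, here is a minimal standalone sketch of the pattern this commit applies: fall back to the EOS id when the tokenizer has no pad token, build an attention mask from the non-pad positions, and pass both pad_token_id and attention_mask to generate(). The checkpoint name and mask construction are taken from the diff; the prompt and generation settings are illustrative assumptions, and streaming is omitted to keep the sketch short.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Ministral-8B-Instruct-2410")
model = AutoModelForCausalLM.from_pretrained("mistralai/Ministral-8B-Instruct-2410", device_map="auto")

# Many decoder-only tokenizers ship without a pad token; reusing the EOS id
# is the common fallback, and it is what the commit does.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

conversation = [{"role": "user", "content": "What is 9,000 * 9,000?"}]  # sample prompt, not from the app
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)

# Mark every non-pad position as attended. For a single unpadded prompt this
# is all ones; passing it explicitly avoids the transformers warning about a
# missing attention mask when the pad and EOS ids coincide.
attention_mask = input_ids.ne(tokenizer.pad_token_id).to(dtype=torch.long, device=model.device)

output = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    pad_token_id=tokenizer.pad_token_id,  # explicit, so generate() does not have to guess
    max_new_tokens=64,
)
# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))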