OscarFAI committed
Commit f1d7efb · Parent(s): 9577ec2
Files changed (1):
  1. app.py (+19 -5)
app.py CHANGED
@@ -3,6 +3,7 @@ import os
 import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
+import torch
 
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -43,6 +44,10 @@ h1 {
 tokenizer = AutoTokenizer.from_pretrained("mistralai/Ministral-8B-Instruct-2410")
 model = AutoModelForCausalLM.from_pretrained("mistralai/Ministral-8B-Instruct-2410", device_map="auto")
 
+# Ensure we have a pad token
+if tokenizer.pad_token_id is None:
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
 terminators = [
     tokenizer.eos_token_id,
     tokenizer.convert_tokens_to_ids("<|eot_id|>")
@@ -88,17 +93,23 @@ def chat_mistral(message: str,
     first_message = f"{formatted_prompt}{message}" if formatted_prompt else message
     conversation.append({"role": "user", "content": first_message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-
+    # Tokenize with padding and attention mask
+    input_data = tokenizer.apply_chat_template(conversation, return_tensors="pt", padding=True, truncation=True)
+    input_ids = input_data.to(model.device)
+
+    attention_mask = input_ids.ne(tokenizer.pad_token_id).to(dtype=torch.long, device=model.device)
+
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
         input_ids=input_ids,
+        attention_mask=attention_mask,  # Fixes the warning
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
         top_p=top_p,
+        pad_token_id=tokenizer.pad_token_id,  # Explicitly set
         eos_token_id=terminators,
     )
 
@@ -139,10 +150,13 @@ with gr.Blocks(fill_height=True, css=css) as demo:
             gr.Slider(minimum=128, maximum=4096, step=1, value=4096, label="Max new tokens", render=False),
         ],
         examples=[
-            ['Are you a sentient being?']
+            ['How to setup a human base on Mars? Give short answer.'],
+            ['Explain theory of relativity to me like I’m 8 years old.'],
+            ['What is 9,000 * 9,000?'],
+            ['Write a pun-filled happy birthday message to my friend Alex.'],
+            ['Justify why a penguin might make a good king of the jungle.']
         ],
-        cache_examples=False,
-        type='messages',
+        cache_examples=False
     )
 
 if __name__ == "__main__":
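
The substance of this commit is passing an explicit attention_mask and pad_token_id to model.generate(), which silences the transformers warning about an unset pad token. Below is a minimal self-contained sketch of the same pattern outside the Gradio app, using the model from the diff; the prompt, sampling values, and add_generation_prompt flag are illustrative additions, not taken from app.py.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "mistralai/Ministral-8B-Instruct-2410"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Fall back to EOS when no pad token is defined, as the commit does;
# without this, generate() warns and picks one itself.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

conversation = [{"role": "user", "content": "What is 9,000 * 9,000?"}]
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# For a single unpadded prompt this mask is all ones; deriving it from the
# pad token mirrors the commit and keeps generate() from guessing.
attention_mask = input_ids.ne(tokenizer.pad_token_id).long()

# Same threaded-streaming pattern as the app: generate() runs in a worker
# thread while the streamer yields decoded text on the main thread.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs=dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    streamer=streamer,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    pad_token_id=tokenizer.pad_token_id,
)).start()

for text in streamer:
    print(text, end="", flush=True)  # chunks arrive as generation proceeds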
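
On the Gradio side, the commit replaces the single example prompt with five and drops type='messages', keeping cache_examples=False so the examples are not pre-run through the model at startup. A hypothetical stripped-down wiring, with a stub function standing in for chat_mistral so the snippet runs on its own:

import gradio as gr

def chat_fn(message, history):
    # Stand-in for chat_mistral from the diff; echoes instead of generating.
    return f"echo: {message}"

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=chat_fn,
        examples=[['What is 9,000 * 9,000?']],
        cache_examples=False,  # skip pre-computing example outputs at launch
    )

if __name__ == "__main__":
    demo.launch()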