OscarFAI committed
Commit 47dded2 · 1 Parent(s): f1d7efb
Files changed (1)
  1. app.py +19 -48
app.py CHANGED

@@ -3,14 +3,13 @@ import os
 import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
-import torch
 
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 DESCRIPTION = '''
 <div>
-<h1 style="text-align: center;">Mistral 8B Instruct</h1>
+<h1 style="text-align: center;">Mistral Chat</h1>
 </div>
 '''
 
@@ -21,7 +20,7 @@ LICENSE = """
 
 PLACEHOLDER = """
 <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
-   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Mistral-8B</h1>
+   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Mistral Chat 8B</h1>
    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
 </div>
 """
@@ -41,12 +40,8 @@ h1 {
 """
 
 # Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("mistralai/Ministral-8B-Instruct-2410")
-model = AutoModelForCausalLM.from_pretrained("mistralai/Ministral-8B-Instruct-2410", device_map="auto")
-
-# Ensure we have a pad token
-if tokenizer.pad_token_id is None:
-    tokenizer.pad_token_id = tokenizer.eos_token_id
+tokenizer = AutoTokenizer.from_pretrained("Orenguteng/Llama-3-8B-Lexi-Uncensored")
+model = AutoModelForCausalLM.from_pretrained("Orenguteng/Llama-3-8B-Lexi-Uncensored", device_map="auto")
 
 terminators = [
     tokenizer.eos_token_id,
@@ -54,19 +49,17 @@ terminators = [
 ]
 
 @spaces.GPU(duration=120)
-def chat_mistral(message: str,
-                 history: list,
-                 temperature: float,
-                 top_p: float,
-                 max_new_tokens: int,
-                 system_prompt: str) -> str:
+def chat_llama3_8b(message: str,
+                   history: list,
+                   temperature: float,
+                   max_new_tokens: int,
+                   system_prompt: str) -> str:
     """
     Generate a streaming response using the Mistral-8B model.
     Args:
         message (str): The input message.
        history (list): The conversation history used by ChatInterface.
        temperature (float): The temperature for generating the response.
-       top_p (float): The top-p (nucleus) sampling parameter.
        max_new_tokens (int): The maximum number of new tokens to generate.
        system_prompt (str): The system prompt to guide the assistant's behavior.
     Returns:
@@ -74,42 +67,25 @@ def chat_mistral(message: str,
     """
     conversation = []
 
-    # Format system prompt correctly using [INST]
+    # Include system prompt at the beginning if provided
     if system_prompt:
-        formatted_prompt = f"[INST] {system_prompt} [/INST]\n\n"
-    else:
-        formatted_prompt = ""
-
-    # Modify first user message to include system prompt
-    if history:
-        first_user_msg = f"{formatted_prompt}{history[0][0]}" if formatted_prompt else history[0][0]
-        conversation.append({"role": "user", "content": first_user_msg})
-        conversation.append({"role": "assistant", "content": history[0][1]})
-
-        for user, assistant in history[1:]:
-            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    else:
-        # First message in a new conversation
-        first_message = f"{formatted_prompt}{message}" if formatted_prompt else message
-        conversation.append({"role": "user", "content": first_message})
+        conversation.append({"role": "system", "content": system_prompt})
 
-    # Tokenize with padding and attention mask
-    input_data = tokenizer.apply_chat_template(conversation, return_tensors="pt", padding=True, truncation=True)
-    input_ids = input_data.to(model.device)
-
-    attention_mask = input_ids.ne(tokenizer.pad_token_id).to(dtype=torch.long, device=model.device)
+    for user, assistant in history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+
+    conversation.append({"role": "user", "content": message})
 
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
+
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
         input_ids=input_ids,
-        attention_mask=attention_mask,  # Fixes the warning
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
-        top_p=top_p,
-        pad_token_id=tokenizer.pad_token_id,  # Explicitly set
         eos_token_id=terminators,
     )
 
@@ -139,22 +115,17 @@ with gr.Blocks(fill_height=True, css=css) as demo:
     )
 
     gr.ChatInterface(
-        fn=chat_mistral,
+        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            system_prompt_input,
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),
-           gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="Top-p", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=4096, label="Max new tokens", render=False),
        ],
        examples=[
-           ['How to setup a human base on Mars? Give short answer.'],
-           ['Explain theory of relativity to me like I’m 8 years old.'],
-           ['What is 9,000 * 9,000?'],
-           ['Write a pun-filled happy birthday message to my friend Alex.'],
-           ['Justify why a penguin might make a good king of the jungle.']
+           ['Are you a sentient being?']
        ],
        cache_examples=False
    )
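
Note: the hunks above stop just before the part of chat_llama3_8b that actually launches generation and streams tokens back to Gradio; that code is unchanged by this commit and therefore not shown. As a rough sketch only (the thread/yield loop below is assumed, not taken from this diff), the generate_kwargs and TextIteratorStreamer set up above are typically consumed like this:

    # Assumed continuation of chat_llama3_8b (not part of this commit's diff):
    # run generation on a background thread and yield partial text as it arrives.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:            # TextIteratorStreamer yields decoded text chunks
        outputs.append(text)
        yield "".join(outputs)       # ChatInterface re-renders each partial reply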
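
For reference, the conversation list that the rewritten handler builds is the standard chat-messages format expected by apply_chat_template. A small standalone illustration (model id copied from the diff; the example messages are made up):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Orenguteng/Llama-3-8B-Lexi-Uncensored")
    conversation = [
        {"role": "system", "content": "You are a concise assistant."},  # only added when system_prompt is set
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there."},
        {"role": "user", "content": "Are you a sentient being?"},
    ]
    # With tokenize=True (the default), this renders the model's chat template
    # and returns a [1, seq_len] tensor of input ids, as used in the diff above.
    input_ids = tok.apply_chat_template(conversation, return_tensors="pt")
    print(input_ids.shape)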