Daemontatox committed (verified)
Commit b70c257 · 1 Parent(s): 99f29ff

Update app.py

Files changed (1)
  1. app.py +285 -129
app.py CHANGED
@@ -1,23 +1,30 @@
 import os
 import time
-import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 import gradio as gr
 from threading import Thread
 
-MODEL_LIST = ["CohereForAI/aya-expanse-32b"]
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
-MODEL = "CohereForAI/aya-expanse-32b"
-
-TITLE = "<h1><center>Mawred T2 Wip </center></h1>"
-
-PLACEHOLDER = """
-<center>
-<p>Hi! How can I help you today?</p>
-</center>
-"""
-
 
 CSS = """
 .duplicate-button {
@@ -29,150 +36,299 @@ CSS = """
 h3 {
     text-align: center;
 }
 """
 
-device = "cuda" # for GPU usage or "cpu" for CPU usage
 
-quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_compute_dtype=torch.bfloat16,
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type= "nf4")
 
-tokenizer = AutoTokenizer.from_pretrained(MODEL)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-    quantization_config=quantization_config
-)
 
-@spaces.GPU(660)
-def stream_chat(
-    message: str,
     history: list,
     system_prompt: str,
-    temperature: float = 0.8,
-    max_new_tokens: int = 1024,
-    top_p: float = 1.0,
-    top_k: int = 20,
     penalty: float = 1.2,
 ):
-    print(f'message: {message}')
-    print(f'history: {history}')
-
     conversation = [
         {"role": "system", "content": system_prompt}
     ]
     for prompt, answer in history:
         conversation.extend([
-            {"role": "user", "content": prompt},
-            {"role": "assistant", "content": answer},
         ])
-
     conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens = max_new_tokens,
-        do_sample = False if temperature == 0 else True,
-        top_p = top_p,
-        top_k = top_k,
-        temperature = temperature,
         repetition_penalty=penalty,
-        eos_token_id=255001,
         streamer=streamer,
     )
-
     with torch.no_grad():
         thread = Thread(target=model.generate, kwargs=generate_kwargs)
         thread.start()
 
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        yield buffer
-
 
-chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
-with gr.Blocks(css=CSS, theme="soft") as demo:
-    gr.HTML(TITLE)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
-    gr.ChatInterface(
-        fn=stream_chat,
-        chatbot=chatbot,
-        fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
-        additional_inputs=[
-            gr.Textbox(
-                value="""
-                You are a helpful assistant.
-                """,
-                label="System Prompt",
-                lines=5,
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0,
-                maximum=1,
-                step=0.1,
-                value=0.8,
-                label="Temperature",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=128,
-                maximum=8192,
-                step=1,
-                value=1024,
-                label="Max new tokens",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=1.0,
-                step=0.1,
-                value=1.0,
-                label="top_p",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=1,
-                maximum=20,
-                step=1,
-                value=20,
-                label="top_k",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=2.0,
-                step=0.1,
-                value=1.2,
-                label="Repetition penalty",
-                render=False,
-            ),
-        ],
-        examples=[
-            ["Translate 'artificial intelligence' to Arabic."],
-            ["How do you say 'photosynthesis' in Arabic?"],
-            ["Translate 'main causes of climate change' into Arabic."],
-            ["What is the Arabic translation for 'protein synthesis'?"],
-            ["Translate 'key features of a democratic government' to Arabic."],
-            ["How do you translate 'theory of relativity' into Arabic?"],
-            ["What is the Arabic equivalent of 'vaccines prevent diseases'?"],
-            ["Translate 'major events of World War II' to Arabic."],
-            ["How do you say 'structure of a human cell' in Arabic?"],
-            ["Translate 'role of DNA in genetics' into Arabic."]
-        ],
-        cache_examples=False,
-    )
 
 
 if __name__ == "__main__":
     demo.launch()
+import subprocess
+
+subprocess.run(
+    'pip install flash-attn --no-build-isolation',
+    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+    shell=True
+)
 import os
+import re
 import time
 import torch
+import spaces
 import gradio as gr
 from threading import Thread
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    TextIteratorStreamer
+)
 
+# Configuration Constants
+MODEL_ID = "CohereForAI/aya-expanse-32b"
+DEFAULT_SYSTEM_PROMPT = """You are a highly intelligent assistant."""
+# UI Configuration
+TITLE = "<h1><center>AI Reasoning Assistant</center></h1>"
+PLACEHOLDER = "Ask me anything! I'll think through it step by step."
 
 CSS = """
 .duplicate-button {
 h3 {
     text-align: center;
 }
+.message-wrap {
+    overflow-x: auto;
+}
+.message-wrap p {
+    margin-bottom: 1em;
+}
+.message-wrap pre {
+    background-color: #f6f8fa;
+    border-radius: 3px;
+    padding: 16px;
+    overflow-x: auto;
+}
+.message-wrap code {
+    background-color: rgba(175,184,193,0.2);
+    border-radius: 3px;
+    padding: 0.2em 0.4em;
+    font-family: monospace;
+}
+.custom-tag {
+    color: #0066cc;
+    font-weight: bold;
+}
+.chat-area {
+    height: 500px !important;
+    overflow-y: auto !important;
+}
 """
 
+def initialize_model():
+    """Initialize the model with appropriate configurations"""
+    quantization_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_use_double_quant=True
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
 
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        torch_dtype=torch.float16,
+        device_map="cuda",
+        attn_implementation="flash_attention_2",
+        quantization_config=quantization_config
+    )
+
+    return model, tokenizer
+
+def format_text(text):
+    """Format text with proper spacing and tag highlighting (but keep tags visible)"""
+    tag_patterns = [
+        (r'<Thinking>', '\n<Thinking>\n'),
+        (r'</Thinking>', '\n</Thinking>\n'),
+        (r'<Critique>', '\n<Critique>\n'),
+        (r'</Critique>', '\n</Critique>\n'),
+        (r'<Revising>', '\n<Revising>\n'),
+        (r'</Revising>', '\n</Revising>\n'),
+        (r'<Final>', '\n<Final>\n'),
+        (r'</Final>', '\n</Final>\n')
+    ]
+
+    formatted = text
+    for pattern, replacement in tag_patterns:
+        formatted = re.sub(pattern, replacement, formatted)
+
+    formatted = '\n'.join(line for line in formatted.split('\n') if line.strip())
+
+    return formatted
 
+def format_chat_history(history):
+    """Format chat history for display, keeping tags visible"""
+    formatted = []
+    for user_msg, assistant_msg in history:
+        formatted.append(f"User: {user_msg}")
+        if assistant_msg:
+            formatted.append(f"Assistant: {assistant_msg}")
+    return "\n\n".join(formatted)
+
+def create_examples():
+    """Create example queries for the UI"""
+    return [
+        "Explain the concept of artificial intelligence.",
+        "How does photosynthesis work?",
+        "What are the main causes of climate change?",
+        "Describe the process of protein synthesis.",
+        "What are the key features of a democratic government?",
+        "Explain the theory of relativity.",
+        "How do vaccines work to prevent diseases?",
+        "What are the major events of World War II?",
+        "Describe the structure of a human cell.",
+        "What is the role of DNA in genetics?"
+    ]
+
+@spaces.GPU(duration=660)
+def chat_response(
+    message: str,
     history: list,
+    chat_display: str,
     system_prompt: str,
+    temperature: float = 1.0,
+    max_new_tokens: int = 4000,
+    top_p: float = 0.8,
+    top_k: int = 40,
     penalty: float = 1.2,
 ):
+    """Generate chat responses, keeping tags visible in the output"""
     conversation = [
         {"role": "system", "content": system_prompt}
     ]
+
     for prompt, answer in history:
         conversation.extend([
+            {"role": "user", "content": prompt},
+            {"role": "assistant", "content": answer}
         ])
+
     conversation.append({"role": "user", "content": message})
 
+    input_ids = tokenizer.apply_chat_template(
+        conversation,
+        add_generation_prompt=True,
+        return_tensors="pt"
+    ).to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        timeout=60.0,
+        skip_prompt=True,
+        skip_special_tokens=True
+    )
 
     generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        do_sample=False if temperature == 0 else True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
         repetition_penalty=penalty,
         streamer=streamer,
     )
+
+    buffer = ""
+
     with torch.no_grad():
         thread = Thread(target=model.generate, kwargs=generate_kwargs)
         thread.start()
 
+    history = history + [[message, ""]]
+
+    for new_text in streamer:
+        buffer += new_text
+        formatted_buffer = format_text(buffer)
+        history[-1][1] = formatted_buffer
+        chat_display = format_chat_history(history)
 
+        yield history, chat_display
 
+def process_example(example: str) -> tuple:
+    """Process example query and return empty history and updated display"""
+    return [], f"User: {example}\n\n"
 
+def main():
+    """Main function to set up and launch the Gradio interface"""
+    global model, tokenizer
+    model, tokenizer = initialize_model()
+
+    with gr.Blocks(css=CSS, theme="soft") as demo:
+        gr.HTML(TITLE)
+        gr.DuplicateButton(
+            value="Duplicate Space for private use",
+            elem_classes="duplicate-button"
+        )
+
+        with gr.Row():
+            with gr.Column():
+                chat_history = gr.State([])
+                chat_display = gr.TextArea(
+                    value="",
+                    label="Chat History",
+                    interactive=False,
+                    elem_classes=["chat-area"],
+                )
+
+                message = gr.TextArea(
+                    placeholder=PLACEHOLDER,
+                    label="Your message",
+                    lines=3
+                )
+
+                with gr.Row():
+                    submit = gr.Button("Send")
+                    clear = gr.Button("Clear")
+
+                with gr.Accordion("⚙️ Advanced Settings", open=False):
+                    system_prompt = gr.TextArea(
+                        value=DEFAULT_SYSTEM_PROMPT,
+                        label="System Prompt",
+                        lines=5,
+                    )
+                    temperature = gr.Slider(
+                        minimum=0,
+                        maximum=1,
+                        step=0.1,
+                        value=0.2,
+                        label="Temperature",
+                    )
+                    max_tokens = gr.Slider(
+                        minimum=128,
+                        maximum=32000,
+                        step=128,
+                        value=4000,
+                        label="Max Tokens",
+                    )
+                    top_p = gr.Slider(
+                        minimum=0.1,
+                        maximum=1.0,
+                        step=0.1,
+                        value=0.8,
+                        label="Top-p",
+                    )
+                    top_k = gr.Slider(
+                        minimum=1,
+                        maximum=100,
+                        step=1,
+                        value=40,
+                        label="Top-k",
+                    )
+                    penalty = gr.Slider(
+                        minimum=1.0,
+                        maximum=2.0,
+                        step=0.1,
+                        value=1.2,
+                        label="Repetition Penalty",
+                    )
+
+        examples = gr.Examples(
+            examples=create_examples(),
+            inputs=[message],
+            outputs=[chat_history, chat_display],
+            fn=process_example,
+            cache_examples=False,
+        )
+
+        # Set up event handlers
+        submit_click = submit.click(
+            chat_response,
+            inputs=[
+                message,
+                chat_history,
+                chat_display,
+                system_prompt,
+                temperature,
+                max_tokens,
+                top_p,
+                top_k,
+                penalty,
+            ],
+            outputs=[chat_history, chat_display],
+            show_progress=True,
+        )
+
+        message.submit(
+            chat_response,
+            inputs=[
+                message,
+                chat_history,
+                chat_display,
+                system_prompt,
+                temperature,
+                max_tokens,
+                top_p,
+                top_k,
+                penalty,
+            ],
+            outputs=[chat_history, chat_display],
+            show_progress=True,
+        )
+
+        clear.click(
+            lambda: ([], ""),
+            outputs=[chat_history, chat_display],
+            show_progress=True,
+        )
+
+        submit_click.then(lambda: "", outputs=message)
+        message.submit(lambda: "", outputs=message)
+
+    return demo
 
 if __name__ == "__main__":
+    demo = main()
     demo.launch()