chegde committed
Commit 2021bac · verified · 1 Parent(s): f5e15bc

Update app.py

Files changed (1)
  1. app.py +228 -61
app.py CHANGED
@@ -5,22 +5,29 @@ import torch
 from threading import Thread
 
 veri_model_path = "nyu-dice-lab/VeriThoughts-Reasoning-7B"
-
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-# Try loading the model with explicit error handling
+# Try loading the model with KV caching (no flash attention or quantization)
 try:
+    print("Loading tokenizer...")
     veri_tokenizer = AutoTokenizer.from_pretrained(veri_model_path)
-
+
+    # Set pad token if not exists
+    if veri_tokenizer.pad_token is None:
+        veri_tokenizer.pad_token = veri_tokenizer.eos_token
+
+    print("Loading model with KV caching...")
     veri_model = AutoModelForCausalLM.from_pretrained(
         veri_model_path,
-        device_map="auto",
-        torch_dtype="auto",
+        device_map="auto" if torch.cuda.is_available() else None,
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
         trust_remote_code=True,
-        use_cache=True,  # Enable KV caching
-        # attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
+        use_cache=True,  # Enable KV caching for faster generation
+        low_cpu_mem_usage=True
     )
-
+
+    print("Model loaded successfully with KV caching!")
+
 except Exception as e:
     print(f"Model loading error: {e}")
     veri_model = None
@@ -33,100 +40,260 @@ def truncate_at_code_end(text):
         end_index = text.find("CODE END") + len("CODE END")
         return text[:end_index].strip()
     return text.strip()
-
+
 def generate_response(user_message, history):
+    """Non-streaming generation for quick responses"""
     if not veri_model or not veri_tokenizer:
         return history + [["Error", "Model not loaded properly"]]
 
     if not user_message.strip():
         return history
-
-    # Simple generation without streaming first
-    system_message = "You are VeriThoughts, a helpful assistant that thinks step by step. You are finetuned from a Qwen model, created by Alibaba Cloud. If you are asked a Verilog question, make sure your input and output interface has the same names as described in the question. If you are asked to generate code, please start your Verilog code with CODE BEGIN and end with CODE END."
 
+    system_message = "You are VeriThoughts, a helpful assistant that thinks step by step to answer Verilog coding questions. Make sure your input and output interface has the same names as described in the question. Please start your Verilog code with CODE BEGIN and end with CODE END."
+
+    # Create conversation history (limit to last 3 exchanges for memory efficiency)
     conversation = f"System: {system_message}\n"
     recent_history = history[-3:] if len(history) > 3 else history
-
+
     for h in recent_history:
         conversation += f"User: {h[0]}\nAssistant: {h[1]}\n"
     conversation += f"User: {user_message}\nAssistant:"
 
+    # Tokenize input
     inputs = veri_tokenizer(
-        conversation,
+        conversation,
         return_tensors="pt",
         truncation=True,
-        max_length=8192,
-        # padding=True,
-        # return_attention_mask=True
+        max_length=4096,
+        padding=True
    ).to(device)
 
+    # Generate with KV caching
     with torch.no_grad():
         outputs = veri_model.generate(
             **inputs,
-            max_new_tokens=20000,
+            max_new_tokens=1024,
             temperature=0.6,
             top_p=0.95,
             do_sample=True,
-            use_cache=True,  # Enable KV caching for faster generation
-            repetition_penalty=1.1,  # Reduce repetition
-            # length_penalty=1.0,
-            # early_stopping=True,  # Stop early when appropriate
-            # num_beams=1,  # Greedy search for speed
-            # pad_token_id=veri_tokenizer.eos_token_id
+            pad_token_id=veri_tokenizer.pad_token_id,
+            eos_token_id=veri_tokenizer.eos_token_id,
+            use_cache=True,  # KV caching for speed
+            repetition_penalty=1.1,
+            early_stopping=True
         )
 
+    # Decode response
     response = veri_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
-
+
     # Truncate at CODE END to remove repetitive content
-    # response = truncate_at_code_end(response)
+    response = truncate_at_code_end(response)
 
-
+    # Clean up GPU memory
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    # Return updated history
-    return history + [[user_message, response.strip()]]
 
-# Create minimal interface
-with gr.Blocks(
-    title="VeriThoughts-7B Chatbot",
-    css="""
-    .gradio-container {
-        max-width: 1200px !important;
-    }
-    .chat-message {
-        font-size: 14px;
-    }
-    """
-) as demo:
-    gr.Markdown(
-        """
-        # 🤖 VeriThoughts-7B Chatbot
+    return history + [[user_message, response]]
+
+@spaces.GPU(duration=120)
+def generate_response_streaming(user_message, history):
+    """Streaming generation for real-time response display"""
+    if not veri_model or not veri_tokenizer:
+        yield history + [["Error", "Model not loaded properly"]]
+        return
+
+    if not user_message.strip():
+        yield history
+        return
+
+    system_message = "You are VeriThoughts, a helpful assistant that thinks step by step. You are finetuned from a Qwen model, created by Alibaba Cloud, to answer Verilog coding questions. Make sure your input and output interface has the same names as described in the question. Please start your Verilog code with CODE BEGIN and end with CODE END."
+
+    # Create conversation history (limit for memory efficiency)
+    conversation = f"System: {system_message}\n"
+    recent_history = history[-3:] if len(history) > 3 else history
+
+    for h in recent_history:
+        conversation += f"User: {h[0]}\nAssistant: {h[1]}\n"
+    conversation += f"User: {user_message}\nAssistant:"
+
+    try:
+        # Tokenize input
+        inputs = veri_tokenizer(
+            conversation,
+            return_tensors="pt",
+            truncation=True,
+            max_length=2048,
+            padding=True
+        ).to(device)
+
+        # Setup streaming
+        streamer = TextIteratorStreamer(
+            veri_tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True,
+            timeout=30.0
+        )
+
+        # Generation parameters with KV caching
+        generation_kwargs = {
+            **inputs,
+            "max_new_tokens": 4096,
+            "temperature": 0.6,
+            "top_p": 0.95,
+            "do_sample": True,
+            "pad_token_id": veri_tokenizer.pad_token_id,
+            "eos_token_id": veri_tokenizer.eos_token_id,
+            "use_cache": True,  # KV caching for faster streaming
+            "repetition_penalty": 1.1,
+            "streamer": streamer,
+            "early_stopping": True
+        }
+
+        # Start generation in a separate thread
+        thread = Thread(target=veri_model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        # Stream the response token by token
+        generated_text = ""
+        new_history = history + [[user_message, ""]]
+        code_end_reached = False
 
-        An AI assistant specialized in Verilog coding and digital design.
+        for new_text in streamer:
+            # Stop streaming if we've already reached CODE END
+            if code_end_reached:
+                break
+
+            generated_text += new_text
+
+            # Check if CODE END appears in the generated text
+            if "CODE END" in generated_text:
+                # Truncate at CODE END and mark as complete
+                generated_text = truncate_at_code_end(generated_text)
+                code_end_reached = True
+
+            new_history[-1][1] = generated_text
+            yield new_history
+
+            # Break early if CODE END was reached
+            if code_end_reached:
+                break
 
-        **Tips for better results:**
-        - Mention input/output port names clearly
-        - Ask for step-by-step explanations
-        """
+        # Ensure the thread completes
+        thread.join()
+
+        # Final cleanup in case CODE END wasn't reached during streaming
+        if not code_end_reached:
+            final_text = truncate_at_code_end(generated_text)
+            new_history[-1][1] = final_text
+            yield new_history
+
+    except Exception as e:
+        print(f"Streaming error: {e}")
+        error_history = history + [[user_message, f"Streaming error: {str(e)}"]]
+        yield error_history
+
+    finally:
+        # Clean up GPU memory after generation
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+def clear_chat():
+    """Clear chat and clean up memory"""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    return []
+
+# Create the interface
+with gr.Blocks(title="VeriThoughts-7B Chatbot") as demo:
+    gr.Markdown("# VeriThoughts-7B Chatbot")
+    gr.Markdown("*Optimized with KV caching for faster generation*")
+
+    with gr.Row():
+        with gr.Column(scale=4):
+            chatbot = gr.Chatbot(
+                value=[],
+                label="Chat",
+                height=600,
+                show_label=False,
+                container=True
+            )
+
+            with gr.Row():
+                msg = gr.Textbox(
+                    label="Your message",
+                    placeholder="Ask me about Verilog design, syntax, or implementation...",
+                    lines=2,
+                    max_lines=5,
+                    scale=4
+                )
+                send_btn = gr.Button("Send", variant="primary", scale=1)
+
+        with gr.Column(scale=1):
+            with gr.Group():
+                stream_btn = gr.Button("📡 Send (Streaming)", variant="secondary", size="sm")
+                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", size="sm")
+
+            gr.Markdown(
+                """
+                ### 💡 Usage Tips
+
+                **Send**: Quick response (max 1K tokens)
+                **Streaming**: Real-time response (max 4K tokens)
+
+                ### ⚡ Optimizations Active
+                - **KV Caching**: Faster token generation
+                - **Memory Management**: Auto cleanup
+                - **Context Limiting**: Recent history only
+
+                ### 🎯 Best Practices
+                - Be specific about Verilog requirements
+                - Mention input/output port names
+                - Ask for step-by-step explanations
+                - Clear chat periodically
+                """
+            )
+
+    # Event handlers for regular send
+    submit_event = msg.submit(
+        fn=generate_response,
+        inputs=[msg, chatbot],
+        outputs=chatbot,
+        show_progress=True
+    ).then(
+        lambda: "",
+        inputs=None,
+        outputs=msg
     )
 
-    chatbot = gr.Chatbot(value=[], label="Chat")
-    msg = gr.Textbox(label="Your message", placeholder="Ask me about Verilog design, syntax, or implementation...")
-    clear = gr.Button("Clear")
+    send_btn.click(
+        fn=generate_response,
+        inputs=[msg, chatbot],
+        outputs=chatbot,
+        show_progress=True
+    ).then(
+        lambda: "",
+        inputs=None,
+        outputs=msg
+    )
 
-    # Simple event handling
-    msg.submit(
-        fn=generate_response,
-        inputs=[msg, chatbot],
-        outputs=chatbot
+    # Event handler for streaming
+    stream_btn.click(
+        fn=generate_response_streaming,
+        inputs=[msg, chatbot],
+        outputs=chatbot,
+        show_progress=True
     ).then(
-        lambda: "",
-        inputs=None,
+        lambda: "",
+        inputs=None,
         outputs=msg
     )
 
-    clear.click(lambda: [], outputs=chatbot)
+    # Clear chat handler
+    clear_btn.click(
+        fn=clear_chat,
+        inputs=None,
+        outputs=chatbot
+    )
 
-    # Launch without ssr_mode parameter which might cause issues
+    # Launch the app
 demo.launch(share=True)
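The streaming path added here is the standard transformers recipe: `generate()` blocks until it finishes, so it runs on a worker thread while the main thread drains a `TextIteratorStreamer`. A minimal, self-contained sketch of that pattern, using a placeholder `gpt2` checkpoint and prompt rather than the VeriThoughts setup:

```python
# Minimal sketch of the thread + TextIteratorStreamer pattern
# (placeholder gpt2 checkpoint; any causal LM works the same way).
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("A 2-to-1 mux in Verilog:", return_tensors="pt")

# skip_prompt=True keeps the echoed prompt out of the stream; timeout
# prevents the consumer from hanging if the generation thread stalls.
streamer = TextIteratorStreamer(tok, skip_prompt=True,
                                skip_special_tokens=True, timeout=30.0)

# generate() blocks until done, so it runs in a worker thread while the
# main thread consumes decoded text chunks as they become available.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, max_new_tokens=64,
                            use_cache=True, streamer=streamer))
thread.start()

for chunk in streamer:          # yields decoded text as tokens arrive
    print(chunk, end="", flush=True)
thread.join()
```

In the app, the same loop appends each partial string to the chat history and yields it, which is what makes the chatbot repaint as tokens arrive.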
 
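The commit handles the `CODE END` sentinel after the fact: `truncate_at_code_end` trims the decoded string, and the stream loop breaks once the marker shows up. An alternative, not used in this commit, is to stop `generate()` itself at the marker with a custom stopping criterion. A sketch under the assumption of a batch size of 1; `StopOnMarker` is a hypothetical helper name, not part of the commit:

```python
# Hypothetical alternative to post-hoc truncation: halt generation once the
# decoded continuation contains the "CODE END" marker (assumes batch size 1).
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnMarker(StoppingCriteria):
    def __init__(self, tokenizer, prompt_len, marker="CODE END"):
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len   # prompt tokens to skip when decoding
        self.marker = marker

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the generated continuation and check for the marker.
        text = self.tokenizer.decode(input_ids[0, self.prompt_len:],
                                     skip_special_tokens=True)
        return self.marker in text

# Rough usage with the commit's variables:
# criteria = StoppingCriteriaList(
#     [StopOnMarker(veri_tokenizer, inputs["input_ids"].shape[1])])
# outputs = veri_model.generate(**inputs, stopping_criteria=criteria, ...)
```

Re-decoding the continuation on every step adds some overhead, but it keeps the model from spending its token budget past the marker.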
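On the UI side, streaming works because Gradio accepts generator callbacks: every `yield` repaints the bound output component, and `.then()` chains a follow-up event (here, clearing the textbox). A stripped-down sketch of the same wiring, with `echo_stream` standing in for `generate_response_streaming`:

```python
# Stripped-down sketch of the event wiring: a generator callback streams
# into gr.Chatbot, and .then() clears the textbox once it finishes.
import time

import gradio as gr

def echo_stream(user_message, history):
    history = history + [[user_message, ""]]
    for ch in user_message:       # stand-in for model tokens
        history[-1][1] += ch
        time.sleep(0.05)
        yield history             # each yield updates the Chatbot

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(value=[])
    msg = gr.Textbox(label="Your message")
    msg.submit(echo_stream, inputs=[msg, chatbot], outputs=chatbot).then(
        lambda: "", inputs=None, outputs=msg
    )

demo.launch()
```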