Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,22 +5,29 @@ import torch
|
|
5 |
from threading import Thread
|
6 |
|
7 |
veri_model_path = "nyu-dice-lab/VeriThoughts-Reasoning-7B"
|
8 |
-
|
9 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
10 |
|
11 |
-
# Try loading the model with
|
12 |
try:
|
|
|
13 |
veri_tokenizer = AutoTokenizer.from_pretrained(veri_model_path)
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
15 |
veri_model = AutoModelForCausalLM.from_pretrained(
|
16 |
veri_model_path,
|
17 |
-
device_map="auto",
|
18 |
-
torch_dtype=
|
19 |
trust_remote_code=True,
|
20 |
-
use_cache=True, # Enable KV caching
|
21 |
-
|
22 |
)
|
23 |
-
|
|
|
|
|
24 |
except Exception as e:
|
25 |
print(f"Model loading error: {e}")
|
26 |
veri_model = None
|
@@ -33,100 +40,260 @@ def truncate_at_code_end(text):
|
|
33 |
end_index = text.find("CODE END") + len("CODE END")
|
34 |
return text[:end_index].strip()
|
35 |
return text.strip()
|
36 |
-
|
37 |
def generate_response(user_message, history):
|
|
|
38 |
if not veri_model or not veri_tokenizer:
|
39 |
return history + [["Error", "Model not loaded properly"]]
|
40 |
|
41 |
if not user_message.strip():
|
42 |
return history
|
43 |
-
|
44 |
-
# Simple generation without streaming first
|
45 |
-
system_message = "You are VeriThoughts, a helpful assistant that thinks step by step. You are finetuned from a Qwen model, created by Alibaba Cloud. If you are asked a Verilog question, make sure your input and output interface has the same names as described in the question. If you are asked to generate code, please start your Verilog code with CODE BEGIN and end with CODE END."
|
46 |
|
|
|
|
|
|
|
47 |
conversation = f"System: {system_message}\n"
|
48 |
recent_history = history[-3:] if len(history) > 3 else history
|
49 |
-
|
50 |
for h in recent_history:
|
51 |
conversation += f"User: {h[0]}\nAssistant: {h[1]}\n"
|
52 |
conversation += f"User: {user_message}\nAssistant:"
|
53 |
|
|
|
54 |
inputs = veri_tokenizer(
|
55 |
-
conversation,
|
56 |
return_tensors="pt",
|
57 |
truncation=True,
|
58 |
-
max_length=
|
59 |
-
|
60 |
-
# return_attention_mask=True
|
61 |
).to(device)
|
62 |
|
|
|
63 |
with torch.no_grad():
|
64 |
outputs = veri_model.generate(
|
65 |
**inputs,
|
66 |
-
max_new_tokens=
|
67 |
temperature=0.6,
|
68 |
top_p=0.95,
|
69 |
do_sample=True,
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
# pad_token_id=veri_tokenizer.eos_token_id
|
76 |
)
|
77 |
|
|
|
78 |
response = veri_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
|
79 |
-
|
80 |
# Truncate at CODE END to remove repetitive content
|
81 |
-
|
82 |
|
83 |
-
|
84 |
if torch.cuda.is_available():
|
85 |
torch.cuda.empty_cache()
|
86 |
|
87 |
-
|
88 |
-
return history + [[user_message, response.strip()]]
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
.
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
)
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
-
#
|
119 |
-
|
120 |
-
fn=
|
121 |
-
inputs=[msg, chatbot],
|
122 |
-
outputs=chatbot
|
|
|
123 |
).then(
|
124 |
-
lambda: "",
|
125 |
-
inputs=None,
|
126 |
outputs=msg
|
127 |
)
|
128 |
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
-
# Launch
|
132 |
demo.launch(share=True)
|
|
|
5 |
from threading import Thread
|
6 |
|
7 |
veri_model_path = "nyu-dice-lab/VeriThoughts-Reasoning-7B"
|
|
|
8 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
9 |
|
10 |
+
# Try loading the model with KV caching (no flash attention or quantization)
|
11 |
try:
|
12 |
+
print("Loading tokenizer...")
|
13 |
veri_tokenizer = AutoTokenizer.from_pretrained(veri_model_path)
|
14 |
+
|
15 |
+
# Set pad token if not exists
|
16 |
+
if veri_tokenizer.pad_token is None:
|
17 |
+
veri_tokenizer.pad_token = veri_tokenizer.eos_token
|
18 |
+
|
19 |
+
print("Loading model with KV caching...")
|
20 |
veri_model = AutoModelForCausalLM.from_pretrained(
|
21 |
veri_model_path,
|
22 |
+
device_map="auto" if torch.cuda.is_available() else None,
|
23 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
24 |
trust_remote_code=True,
|
25 |
+
use_cache=True, # Enable KV caching for faster generation
|
26 |
+
low_cpu_mem_usage=True
|
27 |
)
|
28 |
+
|
29 |
+
print("Model loaded successfully with KV caching!")
|
30 |
+
|
31 |
except Exception as e:
|
32 |
print(f"Model loading error: {e}")
|
33 |
veri_model = None
|
|
|
40 |
end_index = text.find("CODE END") + len("CODE END")
|
41 |
return text[:end_index].strip()
|
42 |
return text.strip()
|
43 |
+
|
44 |
def generate_response(user_message, history):
|
45 |
+
"""Non-streaming generation for quick responses"""
|
46 |
if not veri_model or not veri_tokenizer:
|
47 |
return history + [["Error", "Model not loaded properly"]]
|
48 |
|
49 |
if not user_message.strip():
|
50 |
return history
|
|
|
|
|
|
|
51 |
|
52 |
+
system_message = "You are VeriThoughts, a helpful assistant that thinks step by step to answer Verilog coding questions. Make sure your input and output interface has the same names as described in the question. Please start your Verilog code with CODE BEGIN and end with CODE END."
|
53 |
+
|
54 |
+
# Create conversation history (limit to last 3 exchanges for memory efficiency)
|
55 |
conversation = f"System: {system_message}\n"
|
56 |
recent_history = history[-3:] if len(history) > 3 else history
|
57 |
+
|
58 |
for h in recent_history:
|
59 |
conversation += f"User: {h[0]}\nAssistant: {h[1]}\n"
|
60 |
conversation += f"User: {user_message}\nAssistant:"
|
61 |
|
62 |
+
# Tokenize input
|
63 |
inputs = veri_tokenizer(
|
64 |
+
conversation,
|
65 |
return_tensors="pt",
|
66 |
truncation=True,
|
67 |
+
max_length=4096,
|
68 |
+
padding=True
|
|
|
69 |
).to(device)
|
70 |
|
71 |
+
# Generate with KV caching
|
72 |
with torch.no_grad():
|
73 |
outputs = veri_model.generate(
|
74 |
**inputs,
|
75 |
+
max_new_tokens=1024,
|
76 |
temperature=0.6,
|
77 |
top_p=0.95,
|
78 |
do_sample=True,
|
79 |
+
pad_token_id=veri_tokenizer.pad_token_id,
|
80 |
+
eos_token_id=veri_tokenizer.eos_token_id,
|
81 |
+
use_cache=True, # KV caching for speed
|
82 |
+
repetition_penalty=1.1,
|
83 |
+
early_stopping=True
|
|
|
84 |
)
|
85 |
|
86 |
+
# Decode response
|
87 |
response = veri_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
|
88 |
+
|
89 |
# Truncate at CODE END to remove repetitive content
|
90 |
+
response = truncate_at_code_end(response)
|
91 |
|
92 |
+
# Clean up GPU memory
|
93 |
if torch.cuda.is_available():
|
94 |
torch.cuda.empty_cache()
|
95 |
|
96 |
+
return history + [[user_message, response]]
|
|
|
97 |
|
98 |
+
@spaces.GPU(duration=120)
|
99 |
+
def generate_response_streaming(user_message, history):
|
100 |
+
"""Streaming generation for real-time response display"""
|
101 |
+
if not veri_model or not veri_tokenizer:
|
102 |
+
yield history + [["Error", "Model not loaded properly"]]
|
103 |
+
return
|
104 |
+
|
105 |
+
if not user_message.strip():
|
106 |
+
yield history
|
107 |
+
return
|
108 |
+
|
109 |
+
system_message = "You are VeriThoughts, a helpful assistant that thinks step by step. You are finetuned from a Qwen model, created by Alibaba Cloud, to answer Verilog coding questions. Make sure your input and output interface has the same names as described in the question. Please start your Verilog code with CODE BEGIN and end with CODE END."
|
110 |
+
|
111 |
+
# Create conversation history (limit for memory efficiency)
|
112 |
+
conversation = f"System: {system_message}\n"
|
113 |
+
recent_history = history[-3:] if len(history) > 3 else history
|
114 |
+
|
115 |
+
for h in recent_history:
|
116 |
+
conversation += f"User: {h[0]}\nAssistant: {h[1]}\n"
|
117 |
+
conversation += f"User: {user_message}\nAssistant:"
|
118 |
+
|
119 |
+
try:
|
120 |
+
# Tokenize input
|
121 |
+
inputs = veri_tokenizer(
|
122 |
+
conversation,
|
123 |
+
return_tensors="pt",
|
124 |
+
truncation=True,
|
125 |
+
max_length=2048,
|
126 |
+
padding=True
|
127 |
+
).to(device)
|
128 |
+
|
129 |
+
# Setup streaming
|
130 |
+
streamer = TextIteratorStreamer(
|
131 |
+
veri_tokenizer,
|
132 |
+
skip_prompt=True,
|
133 |
+
skip_special_tokens=True,
|
134 |
+
timeout=30.0
|
135 |
+
)
|
136 |
+
|
137 |
+
# Generation parameters with KV caching
|
138 |
+
generation_kwargs = {
|
139 |
+
**inputs,
|
140 |
+
"max_new_tokens": 4096,
|
141 |
+
"temperature": 0.6,
|
142 |
+
"top_p": 0.95,
|
143 |
+
"do_sample": True,
|
144 |
+
"pad_token_id": veri_tokenizer.pad_token_id,
|
145 |
+
"eos_token_id": veri_tokenizer.eos_token_id,
|
146 |
+
"use_cache": True, # KV caching for faster streaming
|
147 |
+
"repetition_penalty": 1.1,
|
148 |
+
"streamer": streamer,
|
149 |
+
"early_stopping": True
|
150 |
+
}
|
151 |
+
|
152 |
+
# Start generation in a separate thread
|
153 |
+
thread = Thread(target=veri_model.generate, kwargs=generation_kwargs)
|
154 |
+
thread.start()
|
155 |
+
|
156 |
+
# Stream the response token by token
|
157 |
+
generated_text = ""
|
158 |
+
new_history = history + [[user_message, ""]]
|
159 |
+
code_end_reached = False
|
160 |
|
161 |
+
for new_text in streamer:
|
162 |
+
# Stop streaming if we've already reached CODE END
|
163 |
+
if code_end_reached:
|
164 |
+
break
|
165 |
+
|
166 |
+
generated_text += new_text
|
167 |
+
|
168 |
+
# Check if CODE END appears in the generated text
|
169 |
+
if "CODE END" in generated_text:
|
170 |
+
# Truncate at CODE END and mark as complete
|
171 |
+
generated_text = truncate_at_code_end(generated_text)
|
172 |
+
code_end_reached = True
|
173 |
+
|
174 |
+
new_history[-1][1] = generated_text
|
175 |
+
yield new_history
|
176 |
+
|
177 |
+
# Break early if CODE END was reached
|
178 |
+
if code_end_reached:
|
179 |
+
break
|
180 |
|
181 |
+
# Ensure the thread completes
|
182 |
+
thread.join()
|
183 |
+
|
184 |
+
# Final cleanup in case CODE END wasn't reached during streaming
|
185 |
+
if not code_end_reached:
|
186 |
+
final_text = truncate_at_code_end(generated_text)
|
187 |
+
new_history[-1][1] = final_text
|
188 |
+
yield new_history
|
189 |
+
|
190 |
+
except Exception as e:
|
191 |
+
print(f"Streaming error: {e}")
|
192 |
+
error_history = history + [[user_message, f"Streaming error: {str(e)}"]]
|
193 |
+
yield error_history
|
194 |
+
|
195 |
+
finally:
|
196 |
+
# Clean up GPU memory after generation
|
197 |
+
if torch.cuda.is_available():
|
198 |
+
torch.cuda.empty_cache()
|
199 |
+
|
200 |
+
def clear_chat():
|
201 |
+
"""Clear chat and clean up memory"""
|
202 |
+
if torch.cuda.is_available():
|
203 |
+
torch.cuda.empty_cache()
|
204 |
+
return []
|
205 |
+
|
206 |
+
# Create interface with soft theme
|
207 |
+
with gr.Blocks(title="VeriThoughts-7B Chatbot") as demo:
|
208 |
+
gr.Markdown("# VeriThoughts-7B Chatbot")
|
209 |
+
gr.Markdown("*Optimized with KV caching for faster generation*")
|
210 |
+
|
211 |
+
with gr.Row():
|
212 |
+
with gr.Column(scale=4):
|
213 |
+
chatbot = gr.Chatbot(
|
214 |
+
value=[],
|
215 |
+
label="Chat",
|
216 |
+
height=600,
|
217 |
+
show_label=False,
|
218 |
+
container=True
|
219 |
+
)
|
220 |
+
|
221 |
+
with gr.Row():
|
222 |
+
msg = gr.Textbox(
|
223 |
+
label="Your message",
|
224 |
+
placeholder="Ask me about Verilog design, syntax, or implementation...",
|
225 |
+
lines=2,
|
226 |
+
max_lines=5,
|
227 |
+
scale=4
|
228 |
+
)
|
229 |
+
send_btn = gr.Button("Send", variant="primary", scale=1)
|
230 |
+
|
231 |
+
with gr.Column(scale=1):
|
232 |
+
with gr.Group():
|
233 |
+
stream_btn = gr.Button("📡 Send (Streaming)", variant="secondary", size="sm")
|
234 |
+
clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", size="sm")
|
235 |
+
|
236 |
+
gr.Markdown(
|
237 |
+
"""
|
238 |
+
### 💡 Usage Tips
|
239 |
+
|
240 |
+
**Send**: Quick response (max 1K tokens)
|
241 |
+
**Streaming**: Real-time response (max 2K tokens)
|
242 |
+
|
243 |
+
### ⚡ Optimizations Active
|
244 |
+
- **KV Caching**: Faster token generation
|
245 |
+
- **Memory Management**: Auto cleanup
|
246 |
+
- **Context Limiting**: Recent history only
|
247 |
+
|
248 |
+
### 🎯 Best Practices
|
249 |
+
- Be specific about Verilog requirements
|
250 |
+
- Mention input/output port names
|
251 |
+
- Ask for step-by-step explanations
|
252 |
+
- Clear chat periodically
|
253 |
+
"""
|
254 |
+
)
|
255 |
+
|
256 |
+
# Event handlers for regular send
|
257 |
+
submit_event = msg.submit(
|
258 |
+
fn=generate_response,
|
259 |
+
inputs=[msg, chatbot],
|
260 |
+
outputs=chatbot,
|
261 |
+
show_progress=True
|
262 |
+
).then(
|
263 |
+
lambda: "",
|
264 |
+
inputs=None,
|
265 |
+
outputs=msg
|
266 |
)
|
267 |
|
268 |
+
send_btn.click(
|
269 |
+
fn=generate_response,
|
270 |
+
inputs=[msg, chatbot],
|
271 |
+
outputs=chatbot,
|
272 |
+
show_progress=True
|
273 |
+
).then(
|
274 |
+
lambda: "",
|
275 |
+
inputs=None,
|
276 |
+
outputs=msg
|
277 |
+
)
|
278 |
|
279 |
+
# Event handler for streaming
|
280 |
+
stream_btn.click(
|
281 |
+
fn=generate_response_streaming,
|
282 |
+
inputs=[msg, chatbot],
|
283 |
+
outputs=chatbot,
|
284 |
+
show_progress=True
|
285 |
).then(
|
286 |
+
lambda: "",
|
287 |
+
inputs=None,
|
288 |
outputs=msg
|
289 |
)
|
290 |
|
291 |
+
# Clear chat handler
|
292 |
+
clear_btn.click(
|
293 |
+
fn=clear_chat,
|
294 |
+
inputs=None,
|
295 |
+
outputs=chatbot
|
296 |
+
)
|
297 |
|
298 |
+
# Launch the app
|
299 |
demo.launch(share=True)
|