# Hugging Face Spaces page status at capture time: Sleeping
import gradio as gr | |
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
import torch | |
from threading import Thread | |
# Hugging Face Hub repository that supplies both the tokenizer and the weights.
veri_model_path = "nyu-dice-lab/VeriThoughts-Reasoning-7B"

# Use the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the tokenizer and the model once, at import time.  Any failure is
# printed and leaves BOTH globals set to None, so the UI can still start and
# report "Model not loaded properly" instead of crashing on launch.
try:
    veri_tokenizer = AutoTokenizer.from_pretrained(veri_model_path)
    veri_model = AutoModelForCausalLM.from_pretrained(
        veri_model_path,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
        use_cache=True,  # Enable KV caching
        # attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
    )
except Exception as exc:
    print(f"Model loading error: {exc}")
    veri_model = veri_tokenizer = None
def truncate_at_code_end(text: str) -> str:
    """Drop everything that follows the ``CODE END`` marker closing the code.

    The end marker is searched for starting at the first ``CODE BEGIN`` (or at
    the start of the text when there is no begin marker).  This way a mere
    mention of "CODE END" in the text that precedes the code does not cut the
    actual code off; content generated after the closing marker is removed.

    Args:
        text: Decoded model output.

    Returns:
        The text up to and including the selected ``CODE END`` marker, with
        surrounding whitespace stripped.  When no such marker is found the
        whole text is returned, stripped.
    """
    begin_marker = "CODE BEGIN"
    end_marker = "CODE END"

    # find() yields -1 when the begin marker is absent; clamp to 0 so the
    # search then covers the whole text, as it did before.
    search_from = max(text.find(begin_marker), 0)
    end_at = text.find(end_marker, search_from)
    if end_at == -1:
        return text.strip()
    return text[: end_at + len(end_marker)].strip()
def generate_response(user_message, history):
    """Generate the assistant's reply and return the updated chat history.

    Args:
        user_message: Text submitted through the message textbox.
        history: Chat history as a list of ``[user, assistant]`` pairs; an
            empty list or None is treated as "no history yet".

    Returns:
        A new list: the history plus one ``[user_message, reply]`` pair.  For
        a blank message the history is returned unchanged, and when the model
        or tokenizer failed to load an error entry is appended instead.
    """
    history = history or []

    # Compare against None explicitly: truthiness on a tokenizer goes through
    # its __len__, which is not what this guard means to ask.
    if veri_model is None or veri_tokenizer is None:
        return history + [["Error", "Model not loaded properly"]]
    if not user_message or not user_message.strip():
        return history

    system_message = (
        "You are VeriThoughts, a helpful assistant that thinks step by step. "
        "You are finetuned from a Qwen model, created by Alibaba Cloud. "
        "If you are asked a Verilog question, make sure your input and output "
        "interface has the same names as described in the question. "
        "If you are asked to generate code, please start your Verilog code "
        "with CODE BEGIN and end with CODE END."
    )

    # NOTE(review): the prompt is a hand-rolled System/User/Assistant
    # transcript rather than the tokenizer's chat template -- confirm this is
    # the format the model expects.
    parts = [f"System: {system_message}\n"]
    # Only the three most recent exchanges are replayed to bound prompt size.
    for turn in history[-3:]:
        # A missing side of an exchange becomes empty text rather than the
        # literal string "None".
        past_user = "" if turn[0] is None else turn[0]
        past_reply = "" if turn[1] is None else turn[1]
        parts.append(f"User: {past_user}\nAssistant: {past_reply}\n")
    parts.append(f"User: {user_message}\nAssistant:")
    conversation = "".join(parts)

    # Put the inputs on the device the model actually lives on; with
    # device_map="auto" that is decided by the loader, not by this module.
    inputs = veri_tokenizer(
        conversation,
        return_tensors="pt",
        truncation=True,
        max_length=8192,
    ).to(veri_model.device)

    with torch.no_grad():
        outputs = veri_model.generate(
            **inputs,
            max_new_tokens=4096,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.1,  # Reduce repetition
            use_cache=True,  # Enable KV caching for faster generation
            # Plain sampling, no beam search.  The beam-search-only options
            # (early_stopping, length_penalty) are intentionally not passed:
            # they have no effect here and only trigger warnings.
            num_beams=1,
            pad_token_id=veri_tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_length = inputs["input_ids"].shape[1]
    response = veri_tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    # Truncate at CODE END to remove repetitive content (result is stripped).
    response = truncate_at_code_end(response)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return history + [[user_message, response]]
# Page-level CSS: caps the container width and sets the chat font size.
_PAGE_CSS = """
.gradio-container {
max-width: 1200px !important;
}
.chat-message {
font-size: 14px;
}
"""

# Introductory Markdown rendered above the chat window.
_INTRO_MARKDOWN = """
# 🤖 VeriThoughts-7B Chatbot
An AI assistant specialized in Verilog coding and digital design.
**Tips for better results:**
- Mention input/output port names clearly
- Ask for step-by-step explanations
"""


def _blank_textbox():
    """Return the empty string that resets the message box after a submit."""
    return ""


def _empty_history():
    """Return a fresh, empty chat history for the Clear button."""
    return []


# Build the minimal chat interface: intro text, chat window, input box and a
# clear button.
with gr.Blocks(title="VeriThoughts-7B Chatbot", css=_PAGE_CSS) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    chatbot = gr.Chatbot(value=[], label="Chat")
    msg = gr.Textbox(
        label="Your message",
        placeholder="Ask me about Verilog design, syntax, or implementation...",
    )
    clear = gr.Button("Clear")

    # Submitting first generates the reply into the chat window, then (as a
    # chained step) empties the textbox.
    submit_event = msg.submit(
        fn=generate_response,
        inputs=[msg, chatbot],
        outputs=chatbot,
    )
    submit_event.then(fn=_blank_textbox, inputs=None, outputs=msg)

    clear.click(fn=_empty_history, outputs=chatbot)

# Launched with share=True only; ssr_mode is deliberately not passed because it
# might cause issues.
demo.launch(share=True)