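"""Gradio chat demo for nyu-dice-lab/VeriThoughts-Reasoning-7B, a Qwen-based
reasoning model specialized in Verilog coding and digital design."""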
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
veri_model_path = "nyu-dice-lab/VeriThoughts-Reasoning-7B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
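# The model itself is placed by device_map="auto"; on a single-GPU machine that
# resolves to cuda:0, so moving the tokenized inputs to `device` below lines up.
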
# Try loading the model with explicit error handling
try:
    veri_tokenizer = AutoTokenizer.from_pretrained(veri_model_path)
    veri_model = AutoModelForCausalLM.from_pretrained(
        veri_model_path,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
        use_cache=True,  # Enable KV caching
        # attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
    )
except Exception as e:
    print(f"Model loading error: {e}")
    veri_model = None
    veri_tokenizer = None


def truncate_at_code_end(text):
    """Truncate text at 'CODE END' to remove repetitive content."""
    if "CODE END" in text:
        end_index = text.find("CODE END") + len("CODE END")
        return text[:end_index].strip()
    return text.strip()
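
# Example: truncate_at_code_end("CODE BEGIN module m; endmodule CODE END junk")
# returns "CODE BEGIN module m; endmodule CODE END".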


# @spaces.GPU must decorate the function that actually needs the GPU, so it
# belongs on generate_response rather than on the string helper above.
@spaces.GPU(duration=60)
def generate_response(user_message, history):
    if not veri_model or not veri_tokenizer:
        return history + [["Error", "Model not loaded properly"]]
    if not user_message.strip():
        return history

    # Simple generation without streaming first
    system_message = (
        "You are VeriThoughts, a helpful assistant that thinks step by step. "
        "You are finetuned from a Qwen model, created by Alibaba Cloud. "
        "If you are asked a Verilog question, make sure your input and output "
        "interface has the same names as described in the question. If you are "
        "asked to generate code, please start your Verilog code with CODE BEGIN "
        "and end with CODE END."
    )
    conversation = f"System: {system_message}\n"

    # Keep only the last few turns so the prompt stays within max_length
    recent_history = history[-3:] if len(history) > 3 else history
    for h in recent_history:
        conversation += f"User: {h[0]}\nAssistant: {h[1]}\n"
    conversation += f"User: {user_message}\nAssistant:"
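
    # Hedged alternative: Qwen-style tokenizers ship a chat template, so the
    # same prompt could instead be built with apply_chat_template, e.g.:
    #   messages = [{"role": "system", "content": system_message}]
    #   for u, a in recent_history:
    #       messages += [{"role": "user", "content": u},
    #                    {"role": "assistant", "content": a}]
    #   messages.append({"role": "user", "content": user_message})
    #   conversation = veri_tokenizer.apply_chat_template(
    #       messages, tokenize=False, add_generation_prompt=True)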

    inputs = veri_tokenizer(
        conversation,
        return_tensors="pt",
        truncation=True,
        max_length=8192,
        # padding=True,
        # return_attention_mask=True
    ).to(device)
    with torch.no_grad():
        outputs = veri_model.generate(
            **inputs,
            max_new_tokens=20000,
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
            use_cache=True,  # Enable KV caching for faster generation
            repetition_penalty=1.1,  # Reduce repetition
            # frequency_penalty / presence_penalty are OpenAI API parameters;
            # transformers' generate() does not accept them, so they are removed.
            # top_k=50,  # Top-k sampling for efficiency
            # length_penalty=1.0,
            # early_stopping=True,  # Only meaningful with beam search
            # num_beams=1,  # Single beam for speed
            # pad_token_id=veri_tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (skip the prompt tokens)
    response = veri_tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True
    )
    # Truncate at CODE END to remove repetitive content
    # response = truncate_at_code_end(response)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Return updated history
    return history + [[user_message, response.strip()]]
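
# Hedged sketch: TextIteratorStreamer and Thread are imported above but never
# used ("simple generation without streaming first"). Token streaming could be
# wired in roughly like this, reusing the names from generate_response:
#
#   streamer = TextIteratorStreamer(
#       veri_tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=veri_model.generate,
#          kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20000)).start()
#   partial = ""
#   for chunk in streamer:
#       partial += chunk
#       yield history + [[user_message, partial]]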
# Create minimal interface
with gr.Blocks(
    title="VeriThoughts-7B Chatbot",
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .chat-message {
        font-size: 14px;
    }
    """,
) as demo:
    gr.Markdown(
        """
# 🤖 VeriThoughts-7B Chatbot

An AI assistant specialized in Verilog coding and digital design.

**Tips for better results:**
- Mention input/output port names clearly
- Ask for step-by-step explanations
"""
    )

    chatbot = gr.Chatbot(value=[], label="Chat")
    msg = gr.Textbox(
        label="Your message",
        placeholder="Ask me about Verilog design, syntax, or implementation...",
    )
    clear = gr.Button("Clear")

    # Submit generates a reply, then the textbox is cleared
    msg.submit(
        fn=generate_response,
        inputs=[msg, chatbot],
        outputs=chatbot,
    ).then(
        lambda: "",
        inputs=None,
        outputs=msg,
    )
    clear.click(lambda: [], outputs=chatbot)

# Launch without the ssr_mode parameter, which can cause issues on Spaces.
# Note: share=True is ignored when running on Hugging Face Spaces.
demo.launch(share=True)