import gradio as gr
import spaces  # imported before torch/transformers, as ZeroGPU expects
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread

veri_model_path = "nyu-dice-lab/VeriThoughts-Reasoning-7B"
# Kept for backward compatibility; inputs are moved to ``veri_model.device`` instead,
# which is where ``device_map="auto"`` actually placed the model.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Prompt budget: the most recent tokens are kept when the prompt is longer than this.
MAX_INPUT_TOKENS = 8192
# Number of previous (user, assistant) exchanges included in each prompt.
MAX_HISTORY_TURNS = 3

SYSTEM_MESSAGE = "You are VeriThoughts, a helpful assistant that thinks step by step. You are finetuned from a Qwen model, created by Alibaba Cloud. If you are asked a Verilog question, make sure your input and output interface has the same names as described in the question. If you are asked to generate code, please start your Verilog code with CODE BEGIN and end with CODE END."

# Load the model once at import time. On failure, keep the UI up and report the
# problem in the chat instead of crashing the Space.
try:
    veri_tokenizer = AutoTokenizer.from_pretrained(veri_model_path)
    veri_model = AutoModelForCausalLM.from_pretrained(
        veri_model_path,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
        use_cache=True,  # enable KV caching
    )
except Exception as e:
    print(f"Model loading error: {e}")
    veri_model = None
    veri_tokenizer = None


def truncate_at_code_end(text):
    """Cut ``text`` right after the first ``CODE END`` marker to drop trailing repetition.

    Args:
        text: Raw model output.

    Returns:
        The text up to and including the first ``CODE END``, stripped of surrounding
        whitespace; if the marker is absent, the whole text stripped.
    """
    marker = "CODE END"
    start = text.find(marker)
    if start != -1:
        return text[: start + len(marker)].strip()
    return text.strip()


def _build_input_ids(user_message, history):
    """Tokenize system prompt + recent history + new message into a (1, seq_len) id tensor."""
    recent_history = history[-MAX_HISTORY_TURNS:]

    if getattr(veri_tokenizer, "chat_template", None):
        # The model is Qwen-derived, so format the conversation with its own chat
        # template rather than an ad-hoc "User:/Assistant:" transcript.
        messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
        for past_user, past_assistant in recent_history:
            messages.append({"role": "user", "content": str(past_user)})
            messages.append({"role": "assistant", "content": str(past_assistant or "")})
        messages.append({"role": "user", "content": user_message})
        input_ids = veri_tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,  # end the prompt with the assistant turn header
            return_tensors="pt",
            return_dict=True,
        )["input_ids"]
    else:
        # Fallback for tokenizers without a chat template: the original plain-text format.
        conversation = f"System: {SYSTEM_MESSAGE}\n"
        for past_user, past_assistant in recent_history:
            conversation += f"User: {past_user}\nAssistant: {past_assistant}\n"
        conversation += f"User: {user_message}\nAssistant:"
        input_ids = veri_tokenizer(conversation, return_tensors="pt")["input_ids"]

    # Keep the END of the prompt. Right-side truncation would drop the newest user
    # message and the assistant cue first, which is exactly the part we need.
    return input_ids[:, -MAX_INPUT_TOKENS:]


# NOTE(review): 60 s of GPU time is likely far below what max_new_tokens=20000 needs,
# so long generations may be cut off by ZeroGPU. Confirm the intended budget.
@spaces.GPU(duration=60)
def generate_response(user_message, history):
    """Generate one assistant reply and append it to the chat history.

    Args:
        user_message: Text the user just submitted.
        history: Gradio tuple-style chat history, a list of ``[user, assistant]`` pairs.

    Returns:
        A new history list with ``[user_message, reply]`` appended. If the model failed
        to load, ``["Error", "Model not loaded properly"]`` is appended instead. For a
        blank message, the input history is returned unchanged.
    """
    if veri_model is None or veri_tokenizer is None:
        return history + [["Error", "Model not loaded properly"]]
    if not (user_message or "").strip():
        return history

    input_ids = _build_input_ids(user_message, history).to(veri_model.device)
    # Single unpadded sequence, so every position is attended to.
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = veri_tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = veri_tokenizer.eos_token_id

    with torch.no_grad():
        # Only arguments that HF generate() accepts. OpenAI-style frequency_penalty /
        # presence_penalty are rejected as unused model kwargs.
        outputs = veri_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=20000,
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.1,  # reduce repetition
            use_cache=True,  # KV caching for faster generation
            pad_token_id=pad_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    response = veri_tokenizer.decode(
        outputs[0][input_ids.shape[1]:], skip_special_tokens=True
    )
    # Optional post-processing, currently disabled:
    # response = truncate_at_code_end(response)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return history + [[user_message, response.strip()]]


# Minimal chat interface.
with gr.Blocks(
    title="VeriThoughts-7B Chatbot",
    css="""
    .gradio-container { max-width: 1200px !important; }
    .chat-message { font-size: 14px; }
    """,
) as demo:
    gr.Markdown(
        """
# 🤖 VeriThoughts-7B Chatbot
An AI assistant specialized in Verilog coding and digital design.

**Tips for better results:**
- Mention input/output port names clearly
- Ask for step-by-step explanations
"""
    )

    chatbot = gr.Chatbot(value=[], label="Chat")
    msg = gr.Textbox(
        label="Your message",
        placeholder="Ask me about Verilog design, syntax, or implementation...",
    )
    clear = gr.Button("Clear")

    # Generate a reply, then clear the input box.
    msg.submit(
        fn=generate_response,
        inputs=[msg, chatbot],
        outputs=chatbot,
    ).then(
        lambda: "",
        inputs=None,
        outputs=msg,
    )
    clear.click(lambda: [], outputs=chatbot)

# Launch without the ssr_mode parameter, which might cause issues.
demo.launch(share=True)