import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread

veri_model_path = "nyu-dice-lab/VeriThoughts-Reasoning-7B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Try loading the model with KV caching (no flash attention or quantization)
try:
    print("Loading tokenizer...")
    veri_tokenizer = AutoTokenizer.from_pretrained(veri_model_path)

    # Set pad token if it does not exist
    if veri_tokenizer.pad_token is None:
        veri_tokenizer.pad_token = veri_tokenizer.eos_token

    print("Loading model with KV caching...")
    veri_model = AutoModelForCausalLM.from_pretrained(
        veri_model_path,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
        use_cache=True,  # Enable KV caching for faster generation
        low_cpu_mem_usage=True
    )
    print("Model loaded successfully with KV caching!")
except Exception as e:
    print(f"Model loading error: {e}")
    veri_model = None
    veri_tokenizer = None


def truncate_at_code_end(text):
    """Truncate text at 'CODE END' to remove repetitive trailing content."""
    if "CODE END" in text:
        end_index = text.find("CODE END") + len("CODE END")
        return text[:end_index].strip()
    return text.strip()


@spaces.GPU(duration=60)
def generate_response(user_message, history):
    """Non-streaming generation for quick responses."""
    if not veri_model or not veri_tokenizer:
        return history + [["Error", "Model not loaded properly"]]

    if not user_message.strip():
        return history

    system_message = (
        "You are VeriThoughts, a helpful assistant that thinks step by step to answer "
        "Verilog coding questions. Make sure your input and output interface has the same "
        "names as described in the question. Please start your Verilog code with CODE BEGIN "
        "and end with CODE END."
    )

    # Build the conversation prompt (limit to the last 3 exchanges for memory efficiency)
    conversation = f"System: {system_message}\n"
    recent_history = history[-3:] if len(history) > 3 else history
    for h in recent_history:
        conversation += f"User: {h[0]}\nAssistant: {h[1]}\n"
    conversation += f"User: {user_message}\nAssistant:"

    # Tokenize input
    inputs = veri_tokenizer(
        conversation,
        return_tensors="pt",
        truncation=True,
        max_length=4096,
        padding=True
    ).to(device)

    # Generate with KV caching
    with torch.no_grad():
        outputs = veri_model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
            pad_token_id=veri_tokenizer.pad_token_id,
            eos_token_id=veri_tokenizer.eos_token_id,
            use_cache=True,  # KV caching for speed
            repetition_penalty=1.1
        )

    # Decode only the newly generated tokens (skip the prompt)
    response = veri_tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True
    )

    # Truncate at CODE END to remove repetitive content
    response = truncate_at_code_end(response)

    # Clean up GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return history + [[user_message, response]]


@spaces.GPU(duration=120)
def generate_response_streaming(user_message, history):
    """Streaming generation for real-time response display."""
    if not veri_model or not veri_tokenizer:
        yield history + [["Error", "Model not loaded properly"]]
        return

    if not user_message.strip():
        yield history
        return

    system_message = (
        "You are VeriThoughts, a helpful assistant that thinks step by step. You are "
        "finetuned from a Qwen model, created by Alibaba Cloud, to answer Verilog coding "
        "questions. Make sure your input and output interface has the same names as "
        "described in the question. "
        "Please start your Verilog code with CODE BEGIN and end with CODE END."
    )

    # Build the conversation prompt (limit history for memory efficiency)
    conversation = f"System: {system_message}\n"
    recent_history = history[-3:] if len(history) > 3 else history
    for h in recent_history:
        conversation += f"User: {h[0]}\nAssistant: {h[1]}\n"
    conversation += f"User: {user_message}\nAssistant:"

    try:
        # Tokenize input
        inputs = veri_tokenizer(
            conversation,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
            padding=True
        ).to(device)

        # Set up streaming
        streamer = TextIteratorStreamer(
            veri_tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=30.0
        )

        # Generation parameters with KV caching
        generation_kwargs = {
            **inputs,
            "max_new_tokens": 4096,
            "temperature": 0.6,
            "top_p": 0.95,
            "do_sample": True,
            "pad_token_id": veri_tokenizer.pad_token_id,
            "eos_token_id": veri_tokenizer.eos_token_id,
            "use_cache": True,  # KV caching for faster streaming
            "repetition_penalty": 1.1,
            "streamer": streamer
        }

        # Start generation in a separate thread
        thread = Thread(target=veri_model.generate, kwargs=generation_kwargs)
        thread.start()

        # Stream the response token by token
        generated_text = ""
        new_history = history + [[user_message, ""]]
        code_end_reached = False

        for new_text in streamer:
            # Stop streaming once CODE END has been reached
            if code_end_reached:
                break

            generated_text += new_text

            # Check whether CODE END appears in the generated text
            if "CODE END" in generated_text:
                # Truncate at CODE END and mark as complete
                generated_text = truncate_at_code_end(generated_text)
                code_end_reached = True

            new_history[-1][1] = generated_text
            yield new_history

            # Break early if CODE END was reached
            if code_end_reached:
                break

        # Ensure the generation thread completes
        thread.join()

        # Final cleanup in case CODE END wasn't reached during streaming
        if not code_end_reached:
            final_text = truncate_at_code_end(generated_text)
            new_history[-1][1] = final_text
            yield new_history

    except Exception as e:
        print(f"Streaming error: {e}")
        error_history = history + [[user_message, f"Streaming error: {str(e)}"]]
        yield error_history
    finally:
        # Clean up GPU memory after generation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def clear_chat():
    """Clear chat history and clean up GPU memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return []


# Build the Gradio interface
with gr.Blocks(title="VeriThoughts-7B Chatbot") as demo:
    gr.Markdown("# VeriThoughts-7B Chatbot")
    gr.Markdown("*Optimized with KV caching for faster generation*")

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                value=[],
                label="Chat",
                height=600,
                show_label=False,
                container=True
            )

            with gr.Row():
                msg = gr.Textbox(
                    label="Your message",
                    placeholder="Ask me about Verilog design, syntax, or implementation...",
                    lines=2,
                    max_lines=5,
                    scale=4
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

        with gr.Column(scale=1):
            with gr.Group():
                stream_btn = gr.Button("📡 Send (Streaming)", variant="secondary", size="sm")
                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", size="sm")

            gr.Markdown(
                """
                ### 💡 Usage Tips
                **Send**: Quick response (max 1K tokens)
                **Streaming**: Real-time response (max 4K tokens)

                ### ⚡ Optimizations Active
                - **KV Caching**: Faster token generation
                - **Memory Management**: Auto cleanup
                - **Context Limiting**: Recent history only

                ### 🎯 Best Practices
                - Be specific about Verilog requirements
                - Mention input/output port names
                - Ask for step-by-step explanations
                - Clear chat periodically
                """
            )

    # Event handlers for regular send
    submit_event = msg.submit(
        fn=generate_response,
        inputs=[msg, chatbot],
        outputs=chatbot,
        show_progress=True
    ).then(
        lambda: "",
        inputs=None,
        outputs=msg
    )

    send_btn.click(
        fn=generate_response,
        inputs=[msg, chatbot],
        outputs=chatbot,
        show_progress=True
    ).then(
        lambda: "",
        inputs=None,
        outputs=msg
    )

    # Event handler for streaming
    stream_btn.click(
        fn=generate_response_streaming,
        inputs=[msg, chatbot],
        outputs=chatbot,
        show_progress=True
    ).then(
        lambda: "",
        inputs=None,
        outputs=msg
    )

    # Clear chat handler
    clear_btn.click(
        fn=clear_chat,
        inputs=None,
        outputs=chatbot
    )

# Launch the app
demo.launch(share=True)