File size: 4,555 Bytes
e10c195
89628dd
 
 
 
e10c195
b5db033
e10c195
b5db033
e10c195
b5db033
 
 
ab95480
 
 
 
 
 
 
f9a267b
ab95480
 
b5db033
 
 
 
eb2ed76
 
f486321
 
 
 
 
 
 
b5db033
eb2ed76
b5db033
eb2ed76
 
 
b5db033
 
ab95480
9e21282
b5db033
ab95480
 
 
b5db033
 
9e21282
ab95480
 
 
 
fb7c6fd
f9a267b
 
ab95480
b5db033
 
 
 
d811871
 
b5db033
 
d811871
 
 
ab95480
 
11be9fe
 
d811871
 
 
 
9e21282
 
b5db033
ab95480
f486321
d811871
f486321
 
ab95480
 
9e21282
b5db033
 
eb2ed76
b5db033
eb2ed76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5db033
eb2ed76
 
 
b5db033
 
 
eb2ed76
b5db033
 
 
 
 
eb2ed76
b5db033
 
eb2ed76
 
 
b5db033
89628dd
b5db033
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread

# Hugging Face Hub repository holding the fine-tuned Verilog model.
veri_model_path = "nyu-dice-lab/VeriThoughts-Reasoning-7B"

# First CUDA device when torch can see one, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load tokenizer and model once, at import time. On any failure the error is
# printed and both handles are set to None so callers can detect that the
# model is unavailable instead of the app crashing at startup.
try:
    veri_tokenizer = AutoTokenizer.from_pretrained(veri_model_path)
    veri_model = AutoModelForCausalLM.from_pretrained(
        veri_model_path,
        trust_remote_code=True,
        torch_dtype="auto",
        device_map="auto",
        use_cache=True,  # enable KV caching in the model config
    )
except Exception as e:
    print(f"Model loading error: {e}")
    veri_model = veri_tokenizer = None

def truncate_at_code_end(text):
    """Return *text* cut right after the first ``CODE END`` marker.

    The system prompt in ``generate_response`` asks the model to finish
    generated Verilog with ``CODE END``; this helper drops whatever follows
    the first occurrence of that marker. The result is always stripped of
    surrounding whitespace, and text without the marker is returned whole
    (stripped).
    """
    marker = "CODE END"
    marker_start = text.find(marker)
    if marker_start == -1:
        return text.strip()
    return text[: marker_start + len(marker)].strip()


# NOTE(review): a 60 s GPU allotment is likely far too short for up to
# 20000 new tokens from a 7B model — confirm the intended limits.
@spaces.GPU(duration=60)
def generate_response(user_message, history):
    """Generate the assistant's reply and append it to the chat history.

    Runs inside a GPU slot requested via ``spaces.GPU`` (duration=60).

    Args:
        user_message: Text the user just submitted.
        history: Chatbot history as a list of ``[user, assistant]`` pairs.

    Returns:
        A new history list with ``[user_message, reply]`` appended. If the
        model or tokenizer failed to load, an ``["Error", ...]`` pair is
        appended instead; a blank message returns *history* unchanged.
    """
    if veri_model is None or veri_tokenizer is None:
        return history + [["Error", "Model not loaded properly"]]

    if not user_message.strip():
        return history

    system_message = "You are VeriThoughts, a helpful assistant that thinks step by step. You are finetuned from a Qwen model, created by Alibaba Cloud. If you are asked a Verilog question, make sure your input and output interface has the same names as described in the question. If you are asked to generate code, please start your Verilog code with CODE BEGIN and end with CODE END."

    # Plain-text transcript prompt: system line, at most the last three
    # exchanges, then the new user turn ending with an "Assistant:" cue.
    conversation = f"System: {system_message}\n"
    for past_turn in history[-3:]:
        conversation += f"User: {past_turn[0]}\nAssistant: {past_turn[1]}\n"
    conversation += f"User: {user_message}\nAssistant:"

    # Tokenize without truncation, then keep only the newest tokens.
    # Right-side truncation (the usual tokenizer default) would cut off the
    # new question and the "Assistant:" cue once long earlier replies push
    # the prompt past the limit.
    max_prompt_tokens = 8192
    encoded = veri_tokenizer(conversation, return_tensors="pt")
    inputs = {
        name: tensor[:, -max_prompt_tokens:].to(device)
        for name, tensor in encoded.items()
    }

    with torch.no_grad():
        outputs = veri_model.generate(
            **inputs,
            max_new_tokens=20000,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            repetition_penalty=1.1,  # discourage repeated text
            use_cache=True,  # reuse the KV cache across decoding steps
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = inputs["input_ids"].shape[1]
    response = veri_tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    # truncate_at_code_end(response) could drop text after "CODE END";
    # it is currently not applied.

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return history + [[user_message, response.strip()]]

# Minimal chat UI: a header, the conversation pane, an input box and a
# button that wipes the conversation.
with gr.Blocks(
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .chat-message {
        font-size: 14px;
    }
    """,
    title="VeriThoughts-7B Chatbot",
) as demo:
    gr.Markdown(
        """
        # 🤖 VeriThoughts-7B Chatbot
        
        An AI assistant specialized in Verilog coding and digital design.
        
        **Tips for better results:**
        - Mention input/output port names clearly
        - Ask for step-by-step explanations
        """
    )

    chatbot = gr.Chatbot(label="Chat", value=[])
    msg = gr.Textbox(
        placeholder="Ask me about Verilog design, syntax, or implementation...",
        label="Your message",
    )
    clear = gr.Button("Clear")

    # Submitting the textbox runs generate_response on (message, history)
    # and displays the returned history; afterwards the textbox is emptied.
    _submit_event = msg.submit(
        inputs=[msg, chatbot],
        outputs=chatbot,
        fn=generate_response,
    )
    _submit_event.then(lambda: "", outputs=msg, inputs=None)

    # The Clear button resets the conversation to an empty history.
    clear.click(lambda: [], outputs=chatbot)

# Start the server only when this file is run as a script, so importing the
# module (e.g. from tooling or tests) does not launch the app. ssr_mode is
# deliberately left at its default; passing it was reported to cause issues.
if __name__ == "__main__":
    demo.launch(share=True)