import os
import subprocess

# Install flash-attn at startup (common Hugging Face Spaces workaround).
# Merge with the existing environment so pip and its tooling stay on PATH.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

import torch
import spaces
import gradio as gr
import flash_attn  # not used directly; fails fast if the runtime install above did not succeed
from threading import Thread
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig, 
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList
)

MODEL_ID = "unsloth/QwQ-32B-unsloth-bnb-4bit"

DEFAULT_SYSTEM_PROMPT = """
Think step by step and explain your reasoning clearly. Break down the problem into logical components, verify each step, and ensure consistency before arriving at the final answer.

For complex reasoning tasks:
- If there are multiple possible solutions, consider each one before selecting the best answer.
- Use intermediate calculations and justify each step before proceeding.
- If relevant, include real-world analogies to improve clarity.
"""

CSS = """
.gr-chatbot { min-height: 500px; border-radius: 15px; }
.special-tag { color: #2ecc71; font-weight: 600; }
footer { display: none !important; }
"""

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop when the EOS token is generated (relies on the module-level
        # tokenizer created by initialize_model() further below).
        return input_ids[0][-1] == tokenizer.eos_token_id
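
# A minimal sketch (not wired into the app) of a string-based stop criterion, e.g.
# to halt generation once the closing "[/Answer]" tag appears instead of waiting
# for EOS. The class name and window size are illustrative assumptions.
class StopOnSubstring(StoppingCriteria):
    def __init__(self, stop_string="[/Answer]", window=16):
        self.stop_string = stop_string
        self.window = window  # only decode the last few tokens for speed

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        tail = tokenizer.decode(input_ids[0][-self.window:], skip_special_tokens=True)
        return self.stop_string in tail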

def initialize_model():
    # 4-bit quantization config, kept for reference: the unsloth checkpoint above
    # already ships pre-quantized in bnb 4-bit, so it is not passed to from_pretrained.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        # quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
    )
    # No explicit .to("cuda") here: device_map already places the weights, and
    # .to() raises an error for bitsandbytes-quantized models.
    model.eval()  # evaluation mode: disables dropout and gradient tracking

    return model, tokenizer

def format_response(text):
    # List of replacements to format key tokens with HTML for styling.
    replacements = [
        ("[Understand]", '\n<strong class="special-tag">[Understand]</strong>\n'),
        ("[think]", '\n<strong class="special-tag">[think]</strong>\n'),
        ("[/think]", '\n<strong class="special-tag">[/think]</strong>\n'),
        ("[Answer]", '\n<strong class="special-tag">[Answer]</strong>\n'),
        ("[/Answer]", '\n<strong class="special-tag">[/Answer]</strong>\n'),
    ]
    for old, new in replacements:
        text = text.replace(old, new)
    return text

# --- Helper: custom conversation template ---
def apply_llama3_chat_template(conversation, add_generation_prompt=True):
    """
    Convert the conversation (a list of dicts with 'role' and 'content') into a
    single prompt string using custom <|SYSTEM|>/<|USER|>/<|ASSISTANT|> role tags
    (a Llama-3-inspired layout, not the model's native chat template).
    """
    prompt = ""
    for msg in conversation:
        role = msg["role"].upper()
        if role == "SYSTEM":
            prompt += "<|SYSTEM|>\n" + msg["content"].strip() + "\n"
        elif role == "USER":
            prompt += "<|USER|>\n" + msg["content"].strip() + "\n"
        elif role == "ASSISTANT":
            prompt += "<|ASSISTANT|>\n" + msg["content"].strip() + "\n"
    if add_generation_prompt:
        prompt += "<|ASSISTANT|>\n"
    return prompt
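
# An alternative sketch (not wired in): QwQ/Qwen checkpoints ship a native chat
# template, so the tokenizer's built-in formatter could replace the hand-rolled
# tags above. The helper name here is illustrative, not part of the original app.
def apply_native_chat_template(conversation, add_generation_prompt=True):
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )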

@spaces.GPU(duration=120)
def generate_response(message, chat_history, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty):
    # Build the conversation history.
    conversation = [{"role": "system", "content": system_prompt}]
    for user_msg, bot_msg in chat_history:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": bot_msg})
    conversation.append({"role": "user", "content": message})

    # Build the prompt with the custom conversation template.
    prompt = apply_llama3_chat_template(conversation, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Set up the streamer to yield new tokens as they are generated; skip_prompt
    # keeps the input prompt out of the streamed output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Prepare generation parameters, including the extra sampling options.
    generate_kwargs = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": temperature > 0,  # sampling options only apply when do_sample is True
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
    }

    # Run the generation inside a no_grad block for speed.
    def generate_inference():
        with torch.inference_mode():
            model.generate(**generate_kwargs)
    Thread(target=generate_inference, daemon=True).start()

    # Stream the output tokens.
    partial_message = ""
    new_history = chat_history + [(message, "")]
    for new_token in streamer:
        partial_message += new_token
        formatted = format_response(partial_message)
        new_history[-1] = (message, formatted + "▌")
        yield new_history

    # Final update without the cursor.
    new_history[-1] = (message, format_response(partial_message))
    yield new_history
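
# A minimal, non-streaming sketch (not used by the Gradio UI) for quick local
# testing of the same prompt pipeline; the defaults here are illustrative assumptions.
def generate_once(message, system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=512):
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": message},
    ]
    prompt = apply_llama3_chat_template(conversation, add_generation_prompt=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)
    # Decode only the newly generated continuation, not the prompt.
    return tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)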

# Initialize the model and tokenizer globally.
model, tokenizer = initialize_model()

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    <h1 align="center">🧠 AI Reasoning Assistant</h1>
    <p align="center">Ask me hard questions and see the reasoning unfold.</p>
    """)
    
    chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
    msg = gr.Textbox(label="Your Question", placeholder="Type your question...")

    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
        temperature = gr.Slider(0, 1, value=0.6, label="Creativity (Temperature)")
        max_tokens = gr.Slider(128, 32768, 32768, label="Max Response Length")
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top P (Nucleus Sampling)")
        top_k = gr.Slider(0, 100, value=35, label="Top K")
        repetition_penalty = gr.Slider(0.5, 2.0, value=1.1, label="Repetition Penalty")

    clear = gr.Button("Clear History")
    
    # Link the input textbox with the generation function.
    msg.submit(
        generate_response,
        [msg, chatbot, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty],
        chatbot,
        show_progress=True
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.queue().launch()