File size: 5,740 Bytes
51a7d9e
13880c3
51a7d9e
edb9e8a
13880c3
 
 
 
c8e2710
 
 
13880c3
51a7d9e
5ab9353
 
02ffc17
86de665
 
6886ae0
 
 
 
 
 
 
 
86de665
1854cbf
51a7d9e
c701791
 
 
51a7d9e
1e18916
c8e2710
 
86de665
c701791
c8e2710
13880c3
86de665
e339ee0
32359f6
 
 
 
e339ee0
 
c8e2710
 
d8a8bf1
e339ee0
13880c3
e4c72cc
afeb266
c8e2710
 
86de665
 
 
3738ef6
13880c3
659ca36
c8e2710
86de665
 
 
 
 
 
 
 
 
 
 
 
36ed55f
86de665
 
c701791
 
86de665
 
c701791
 
86de665
c701791
 
 
 
 
 
86de665
c701791
 
86de665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c701791
 
 
 
 
 
 
 
86de665
c701791
 
 
86de665
c701791
 
 
 
 
86de665
c701791
 
 
 
86de665
c701791
 
86de665
 
 
 
 
c701791
 
 
86de665
c701791
 
86de665
 
1854cbf
c701791
 
3738ef6
51a7d9e
c701791
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import re
from threading import Thread

import torch
import spaces
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Hugging Face Hub repository id of the checkpoint loaded by initialize_model().
MODEL_ID = "Daemontatox/PathFinderAI-S1"
# Alternative checkpoint, kept for quick switching:
# MODEL_ID = "Daemontatox/Research_PathfinderAI"

# Initial value of the "System Instructions" text area in the UI. The text is
# sent verbatim (including its leading blank line and indentation) as the
# system message of every conversation unless the user edits it.
DEFAULT_SYSTEM_PROMPT = """

        Respond in the following format:
[reasoning]
[your reasoning]
[/reasoning]
[answer]
[your answer]
[/answer]
put your final answer within $boxed{}$
"""  # You can modify the default system instructions here

# Stylesheet handed to gr.Blocks. `.special-tag` is the class that
# format_response() puts on the <strong> wrappers around section tags.
CSS = """
.gr-chatbot { min-height: 500px; border-radius: 15px; }
.special-tag { color: #2ecc71; font-weight: 600; }
footer { display: none !important; }
"""

class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation once the EOS token is emitted.

    Only the first sequence of the batch is inspected, which matches how this
    app calls ``generate`` (one conversation at a time). The EOS id is read
    from the module-level ``tokenizer``.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the most recently generated token is the EOS token.

        Args:
            input_ids: Token ids generated so far; indexed as ``[batch, position]``.
            scores: Prediction scores supplied by ``generate`` (unused).
            **kwargs: Extra arguments forwarded by ``generate`` (unused).

        Returns:
            A plain Python ``bool`` (not a 0-dim tensor), as the annotation promises.
        """
        eos_token_id = tokenizer.eos_token_id
        if eos_token_id is None:
            # A tokenizer without an EOS token can never trigger this criterion.
            return False
        # int(...) pulls the scalar off the device so the comparison yields a real bool.
        return int(input_ids[0, -1]) == eos_token_id

def initialize_model():
    """Load the 4-bit quantized causal LM and its tokenizer.

    Returns:
        tuple: ``(model, tokenizer)`` where the model is placed on the GPU and
        switched to evaluation mode, and the tokenizer pads with its EOS token.
    """
    # NF4 4-bit weights with double quantization keep memory usage low;
    # matrix multiplications are computed in bfloat16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    # NOTE(security): trust_remote_code=True executes Python code shipped in
    # the model repository. Only use it with repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    # No model.to("cuda") here: device_map="cuda" has already placed the
    # weights, and calling .to() on a bitsandbytes-quantized model raises a
    # ValueError on many transformers/bitsandbytes versions.
    model.eval()  # inference behaviour for layers such as dropout; does not disable gradients

    return model, tokenizer

# Section tags the model may emit, with or without a closing slash. Matching is
# case-insensitive so that both the "[Reason]"/"[Answer]" spelling and the
# "[reasoning]"/"[answer]" spelling are recognised.
_SPECIAL_TAG_RE = re.compile(r"\[/?(?:understand|reasoning|reason|answer)\]", re.IGNORECASE)


def format_response(text):
    """Wrap section tags in styled HTML so they stand out in the chat window.

    Every recognised tag (e.g. ``[Reason]``, ``[/answer]``) is replaced by the
    same tag inside ``<strong class="special-tag">``, surrounded by newlines.
    The tag text itself is preserved exactly as the model wrote it.

    Args:
        text: Raw (possibly partial) model output.

    Returns:
        The text with every recognised tag wrapped; other content is unchanged.
    """
    # A single regex pass means an already-inserted wrapper can never be
    # re-matched by a later replacement rule.
    return _SPECIAL_TAG_RE.sub(
        lambda match: f'\n<strong class="special-tag">{match.group(0)}</strong>\n',
        text,
    )

@spaces.GPU(duration=120)
def generate_response(message, chat_history, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty):
    """Stream a model reply to ``message``, yielding the growing chat history.

    Args:
        message: The new user message.
        chat_history: Previous ``(user, assistant)`` turns; ``None`` is treated as empty.
        system_prompt: Text sent as the system message.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_tokens: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        repetition_penalty: Penalty applied to already generated tokens.

    Yields:
        The chat history with the in-progress reply appended as the last turn.
        Intermediate updates end with a cursor character; the final one does not.

    Raises:
        gr.Error: If generation fails in the background thread.
    """
    chat_history = list(chat_history or [])

    # Nothing to answer: leave the conversation untouched.
    if not message or not message.strip():
        yield chat_history
        return

    # Build the conversation history.
    # NOTE(review): bot_msg is what was last displayed, i.e. the output of
    # format_response() including its HTML wrappers, so the model sees that
    # markup in earlier turns.
    conversation = [{"role": "system", "content": system_prompt}]
    for user_msg, bot_msg in chat_history:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": bot_msg})
    conversation.append({"role": "user", "content": message})

    # Tokenize the conversation. (This assumes the tokenizer has a chat template.)
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # skip_prompt=True keeps the system prompt and earlier turns out of the
    # streamed text; without it the whole prompt is echoed into the reply.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Sampling parameters only take effect with do_sample=True, and a
    # temperature of 0 is invalid when sampling, so fall back to greedy decoding.
    do_sample = temperature > 0
    generate_kwargs = {
        "input_ids": input_ids,
        # Single unpadded sequence: every position is a real token. Passing the
        # mask explicitly matters because the pad token equals the EOS token.
        "attention_mask": torch.ones_like(input_ids),
        "streamer": streamer,
        "max_new_tokens": int(max_tokens),
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
        "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
    }
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=int(top_k))

    generation_errors = []

    def generate_inference():
        # Runs in a worker thread so tokens can be consumed while generating.
        try:
            with torch.inference_mode():
                model.generate(**generate_kwargs)
        except Exception as exc:  # thread boundary: hand the failure to the consumer
            generation_errors.append(exc)
            # Unblock the consumer loop below; otherwise it would wait forever.
            streamer.end()

    Thread(target=generate_inference, daemon=True).start()

    # Stream the output tokens.
    partial_message = ""
    new_history = chat_history + [(message, "")]
    for new_text in streamer:
        partial_message += new_text
        new_history[-1] = (message, format_response(partial_message) + "▌")
        yield new_history

    # Final update without the cursor.
    new_history[-1] = (message, format_response(partial_message))
    yield new_history

    if generation_errors:
        raise gr.Error(f"Generation failed: {generation_errors[0]}") from generation_errors[0]

# Initialize the model and tokenizer globally.
# This runs at import time; generate_response() and StopOnTokens read these
# module-level names rather than receiving them as arguments.
model, tokenizer = initialize_model()

# Build the UI. Components are created in the order they appear inside the
# Blocks context, so statement order here determines the page layout.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    <h1 align="center">🧠 AI Reasoning Assistant</h1>
    <p align="center">Ask me hard questions and see the reasoning unfold.</p>
    """)
    
    # NOTE(review): CSS styles the ".gr-chatbot" class while this component
    # only sets elem_id="chatbot" — confirm the selector actually matches.
    chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
    msg = gr.Textbox(label="Your Question", placeholder="Type your question...")

    # Generation settings; each value is passed to generate_response() on submit.
    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
        temperature = gr.Slider(0, 1, value=0.6, label="Creativity (Temperature)")
        # Positional arguments here are minimum, maximum, initial value.
        max_tokens = gr.Slider(128, 8192, 4096, label="Max Response Length")
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top P (Nucleus Sampling)")
        top_k = gr.Slider(0, 100, value=50, label="Top K")
        repetition_penalty = gr.Slider(0.5, 2.0, value=1.1, label="Repetition Penalty")

    clear = gr.Button("Clear History")
    
    # Link the input textbox with the generation function.
    # The input list order must match generate_response()'s parameter order;
    # every history the generator yields is written back to the chatbot.
    msg.submit(
        generate_response,
        [msg, chatbot, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty],
        chatbot,
        show_progress=True
    )
    # Writing None to the chatbot empties the conversation.
    clear.click(lambda: None, None, chatbot, queue=False)

# Enable the request queue before launching; the app streams its replies from
# a generator function.
if __name__ == "__main__":
    demo.queue().launch()