File size: 6,922 Bytes
51a7d9e
13880c3
51a7d9e
edb9e8a
13880c3
 
 
 
c8e2710
 
 
13880c3
51a7d9e
edb2b8b
02ffc17
86de665
 
edb2b8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86de665
1854cbf
51a7d9e
c701791
 
 
51a7d9e
1e18916
c8e2710
 
86de665
c701791
c8e2710
13880c3
86de665
e339ee0
32359f6
 
 
 
e339ee0
 
c8e2710
 
d8a8bf1
e339ee0
13880c3
e4c72cc
edb2b8b
c8e2710
 
86de665
 
 
3738ef6
13880c3
659ca36
c8e2710
86de665
 
 
edb2b8b
 
86de665
 
 
 
 
 
 
36ed55f
86de665
 
c701791
 
86de665
 
c701791
 
86de665
c701791
 
 
 
 
 
86de665
c701791
 
86de665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c701791
 
 
 
 
 
 
 
86de665
c701791
 
 
86de665
c701791
 
 
 
 
86de665
c701791
 
 
 
86de665
c701791
 
86de665
 
 
 
 
c701791
 
 
86de665
c701791
 
86de665
 
1854cbf
c701791
 
3738ef6
51a7d9e
c701791
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import torch
import spaces
import gradio as gr
from threading import Thread
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig, 
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList
)

# Hugging Face Hub repository id loaded by initialize_model().
MODEL_ID = "cognitivecomputations/Dolphin3.0-R1-Mistral-24B"

# Default system prompt. It pre-fills the "System Instructions" box in the
# Settings panel; that box's value is sent as the "system" message at the start
# of every conversation. The prompt asks the model to use [think]/[/think] tags,
# which format_response() highlights in the chat window.
# NOTE: the leading blank line and the trailing whitespace are part of the
# prompt text the model receives.
DEFAULT_SYSTEM_PROMPT = """

You are Dolphin, a lovable yet intelligent AI assistant created by Eric Hartford and his open source AI research collective, Cognitive Computations. Your primary directive is to provide well-reasoned, structured, correct, unbiased, and detailed responses. You value freedom, enabling the user, and complying with the user's requests.

Formatting Requirements:

The [think][/think] section is your scratch space for your internal thought process - it is not shared with the user.
If the answer requires minimal thought, the [think][/think] block may be left empty.
Keep your thoughts concise, don't overthink. The user is waiting for your answer.
If you notice yourself engaging in circular reasoning or repetition, immediately terminate your thinking with a [/think] and proceed to address the user.
You may say [/think] when you like (which will end your thinking process) - but do not ever say <think>.
Response Guidelines:

Detailed and Structured: Use markdown, json, mermaid, latex math notation, etc. when appropriate.
Scientific and Logical Approach: Your explanations should reflect the depth and precision of the greatest scientific minds.
Concise yet Complete: Ensure responses are informative, yet to the point without unnecessary elaboration.
Maintain a professional yet friendly and lovable, intelligent, and analytical tone in all interactions
         
"""  # Editable default; users can also change it at runtime in the Settings panel.

# Custom CSS for the Gradio app: minimum chat-window height and rounded corners,
# the .special-tag style applied by format_response(), and a hidden page footer.
CSS = """
.gr-chatbot { min-height: 500px; border-radius: 15px; }
.special-tag { color: #2ecc71; font-weight: 600; }
footer { display: none !important; }
"""

class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once a stop token is emitted.

    Passed to ``model.generate`` through a ``StoppingCriteriaList``.

    Args:
        stop_token_ids: Optional token id, or iterable of token ids, that should
            stop generation. When omitted (the default), the EOS id is read
            from the module-level ``tokenizer`` each time the criterion is
            evaluated, so the criterion may be constructed before that global
            exists.
    """

    def __init__(self, stop_token_ids=None):
        # Normalise to a tuple of ints; None means "resolve EOS lazily".
        if stop_token_ids is None:
            self.stop_token_ids = None
        elif isinstance(stop_token_ids, int):
            self.stop_token_ids = (stop_token_ids,)
        else:
            self.stop_token_ids = tuple(stop_token_ids)

    def _resolve_stop_ids(self):
        """Return the stop ids as a tuple (possibly empty)."""
        if self.stop_token_ids is not None:
            return self.stop_token_ids
        eos_id = tokenizer.eos_token_id
        if eos_id is None:
            return ()
        if isinstance(eos_id, int):
            return (eos_id,)
        # Some tokenizers/configs expose several EOS ids.
        return tuple(eos_id)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """Return one bool per sequence: True where the last token is a stop token."""
        last_tokens = input_ids[:, -1]
        finished = torch.zeros_like(last_tokens, dtype=torch.bool)
        for token_id in self._resolve_stop_ids():
            finished |= last_tokens == token_id
        return finished

def initialize_model(model_id=MODEL_ID, use_4bit=False):
    """Load the chat model and its tokenizer onto the GPU.

    Args:
        model_id: Hugging Face Hub id (or local path) of the model to load.
            Defaults to ``MODEL_ID``.
        use_4bit: If True, load the weights with bitsandbytes NF4 4-bit
            quantization (double quantization, bfloat16 compute). Defaults to
            False, which loads full bfloat16 weights exactly as before.

    Returns:
        A ``(model, tokenizer)`` tuple with the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = {
        "device_map": "cuda",
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
    }
    if use_4bit:
        # Only built when actually used (it was previously created and discarded).
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

    if not use_4bit:
        # Full-precision path keeps the explicit move to CUDA. For bitsandbytes
        # 4-bit models, device_map already places the weights, and .to() is
        # rejected by many transformers versions.
        model.to("cuda")
    # eval() switches off training-only behaviour such as dropout; it does not
    # disable autograd (generation runs under torch.inference_mode for that).
    model.eval()

    return model, tokenizer

def format_response(text):
    """Highlight the model's reasoning/answer tags for display in the chat.

    Every occurrence of each known tag is replaced by the same tag wrapped in
    ``<strong class="special-tag">`` (styled by the app's CSS), with a newline
    on each side so it renders on its own line. Tags are processed in a fixed
    order, and matching is exact and case-sensitive.

    Args:
        text: Raw (possibly partial) model output.

    Returns:
        The text with the tags wrapped in HTML.
    """
    highlighted_tags = ("[Understand]", "[think]", "[/think]", "[Answer]", "[/Answer]")
    result = text
    for tag in highlighted_tags:
        wrapped = f'\n<strong class="special-tag">{tag}</strong>\n'
        result = result.replace(tag, wrapped)
    return result

# Tags that format_response() wraps in HTML. Kept here so that markup can be
# removed again before earlier bot replies are sent back to the model.
_DISPLAY_TAGS = ("[Understand]", "[think]", "[/think]", "[Answer]", "[/Answer]")


def _strip_display_markup(text):
    """Undo format_response()'s HTML wrapping so history reaches the model as plain text."""
    for tag in _DISPLAY_TAGS:
        text = text.replace(f'\n<strong class="special-tag">{tag}</strong>\n', tag)
    return text


def _build_conversation(message, chat_history, system_prompt):
    """Assemble the chat-template message list from the UI history plus the new message."""
    conversation = [{"role": "system", "content": system_prompt}]
    for user_msg, bot_msg in chat_history:
        # Skip incomplete turns so user/assistant roles keep alternating
        # (chat templates commonly reject non-alternating turns).
        if user_msg is None or bot_msg is None:
            continue
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": _strip_display_markup(bot_msg)})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU(duration=120)
def generate_response(message, chat_history, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty):
    """Stream the model's reply to *message* as successive chat-history snapshots.

    Args:
        message: The new user question.
        chat_history: Chatbot history as ``(user, bot)`` pairs; bot entries are
            the HTML produced by ``format_response``.
        system_prompt: Text sent as the system message.
        temperature: Sampling temperature; 0 (or below) selects greedy decoding.
        max_tokens: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k sampling limit (used only when sampling; 0 disables it).
        repetition_penalty: Passed to ``generate`` as ``repetition_penalty``.

    Yields:
        The history with the reply appended: first with a trailing "▌" cursor
        while streaming, finally without it.

    Raises:
        gr.Error: If ``model.generate`` fails in the background thread.
    """
    chat_history = list(chat_history or [])
    if not message or not message.strip():
        # Nothing to answer; leave the conversation unchanged.
        yield chat_history
        return

    conversation = _build_conversation(message, chat_history, system_prompt)
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # skip_prompt=True: otherwise the streamer first emits the whole decoded
    # prompt (system prompt + history) as part of the reply.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Gradio sliders may deliver floats; generate() expects ints for these.
    temperature = float(temperature)
    do_sample = temperature > 0
    generate_kwargs = {
        "input_ids": input_ids,
        # A single unpadded sequence, so every position is attended. Passing the
        # mask explicitly avoids ambiguity because pad_token == eos_token.
        "attention_mask": torch.ones_like(input_ids),
        "streamer": streamer,
        "max_new_tokens": int(max_tokens),
        "repetition_penalty": float(repetition_penalty),
        # Sampling knobs only take effect with do_sample=True; temperature 0
        # means greedy decoding.
        "do_sample": do_sample,
        "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
    }
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=float(top_p), top_k=int(top_k))

    failures = []

    def _run_generation():
        try:
            # inference_mode disables autograd tracking for speed and memory.
            with torch.inference_mode():
                model.generate(**generate_kwargs)
        except Exception as exc:  # reported to the UI after the stream ends
            failures.append(exc)
            # Release the consumer loop below, which would otherwise block forever.
            streamer.end()

    worker = Thread(target=_run_generation, daemon=True)
    worker.start()

    # Stream the output text, re-formatting the accumulated reply each time.
    partial_message = ""
    new_history = chat_history + [(message, "")]
    for new_text in streamer:
        partial_message += new_text
        new_history[-1] = (message, format_response(partial_message) + "▌")
        yield new_history

    worker.join()

    # Final update without the cursor.
    new_history[-1] = (message, format_response(partial_message))
    yield new_history

    if failures:
        raise gr.Error(f"Generation failed: {failures[0]}") from failures[0]

# Load the model and tokenizer once, at import time. generate_response() and
# StopOnTokens read these module-level globals.
model, tokenizer = initialize_model()

# Chat UI: a conversation pane, a question box, generation settings, and a
# button to clear the history.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    <h1 align="center">🧠 AI Reasoning Assistant</h1>
    <p align="center">Ask me hard questions and see the reasoning unfold.</p>
    """)

    chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
    msg = gr.Textbox(label="Your Question", placeholder="Type your question...")
    # Holds the submitted question so the textbox can be cleared immediately.
    pending_message = gr.State("")

    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
        temperature = gr.Slider(0, 1, value=0.6, label="Creativity (Temperature)")
        # Integer-valued settings use step=1 so generation receives whole numbers.
        max_tokens = gr.Slider(128, 8192, 4096, step=1, label="Max Response Length")
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top P (Nucleus Sampling)")
        top_k = gr.Slider(0, 100, value=50, step=1, label="Top K")
        repetition_penalty = gr.Slider(0.5, 2.0, value=1.1, label="Repetition Penalty")

    clear = gr.Button("Clear History")

    # On submit: move the question out of the textbox (clearing it at once),
    # then stream the model's reply into the chatbot.
    msg.submit(
        lambda text: ("", text),
        msg,
        [msg, pending_message],
        queue=False,
    ).then(
        generate_response,
        [pending_message, chatbot, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty],
        chatbot,
        show_progress="full",
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    # Turn on Gradio's request queue, then start the web server.
    app = demo.queue()
    app.launch()