File size: 4,497 Bytes
a59cdce
 
 
 
13a0ef1
a59cdce
 
 
 
 
 
 
13a0ef1
 
a59cdce
 
ecd62fd
a59cdce
 
 
13a0ef1
a59cdce
 
 
 
 
 
 
13a0ef1
a59cdce
 
 
 
 
 
 
 
 
 
 
13a0ef1
a59cdce
 
 
 
13a0ef1
 
a59cdce
 
 
 
 
 
 
 
13a0ef1
a59cdce
13a0ef1
 
 
 
ecd62fd
13a0ef1
a59cdce
13a0ef1
 
a59cdce
13a0ef1
a59cdce
13a0ef1
 
 
 
 
 
a59cdce
13a0ef1
a59cdce
 
 
 
 
 
 
 
 
 
13a0ef1
a59cdce
13a0ef1
a59cdce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13a0ef1
a59cdce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a14d3d
 
 
 
a59cdce
 
 
 
 
 
f3eac83
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread

# Candidate checkpoints; the first entry is the fallback when MODEL_ID is unset.
MODEL_LIST = ["mistralai/mathstral-7B-v0.1"]
# NOTE(review): HF_TOKEN is read here but not referenced by the visible code;
# presumably the hub client picks it up from the environment -- confirm.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# An unset (or empty) MODEL_ID would otherwise hand None to from_pretrained
# below, which fails at import time; fall back to the default checkpoint.
MODEL = os.environ.get("MODEL_ID") or MODEL_LIST[0]

# Page heading rendered at the top of the app.
TITLE = "<h1><center>MathΣtral</center></h1>"

# HTML passed to the chatbot component as its placeholder.
PLACEHOLDER = """
<center>
<p>MathΣtral - I'm MisMath,Your Math advisor</p>
</center>
"""


# Custom styles: the duplicate button and centered h3 headings.
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""

device = "cuda" # for GPU usage or "cpu" for CPU usage

# Tokenizer and model are loaded once at import time and shared by all requests.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    ignore_mismatched_sizes=True)

@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    temperature: float = 0.3,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2,
):
    """Generate a reply to ``message`` and stream it back incrementally.

    Args:
        message: The new user message.
        history: Previous turns as ``(user_prompt, assistant_answer)`` pairs.
        temperature: Sampling temperature; ``0`` (or less) selects greedy
            decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (only used when sampling).
        top_k: Top-k sampling cutoff (only used when sampling).
        penalty: Repetition penalty forwarded to ``generate``. Values that
            are not strictly positive are ignored instead of being passed on.

    Yields:
        The reply accumulated so far, growing with every streamed chunk.
    """
    print(f'message: {message}')
    print(f'history: {history}')

    # Rebuild the conversation in the role/content format used by chat templates.
    conversation = []
    for prompt, answer in history:
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
    # The rendered template already contains the special tokens it needs, so
    # encoding with the default add_special_tokens=True would duplicate them
    # (e.g. a second BOS).
    inputs = tokenizer.encode(
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)  # follow the model's placement (device_map="auto")
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    sampling = temperature > 0
    generate_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=sampling,
        streamer=streamer,
        # NOTE(review): hard-coded pad token id kept from the original code;
        # confirm it matches the tokenizer in use.
        pad_token_id=10,
    )
    if sampling:
        # Only meaningful when sampling; passing them with do_sample=False
        # (notably temperature=0) triggers generation-config complaints.
        # UI sliders may deliver ints, hence the explicit casts.
        generate_kwargs.update(
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    if penalty > 0:
        # Previously this parameter was accepted but never forwarded, so the
        # UI control had no effect. The penalty must be strictly positive.
        generate_kwargs["repetition_penalty"] = float(penalty)

    # Generation runs in a worker thread so this generator can consume the
    # streamer concurrently. torch.no_grad() is thread-local, so wrapping the
    # thread creation in it (as before) did not affect the worker.
    # NOTE(review): generate() is expected to disable gradients itself -- confirm.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

            
# Chat display component, created up front and handed to the ChatInterface below.
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

# Author links shown at the bottom of the page.
footer = """
<div style="text-align: center; margin-top: 20px;">
    <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
    <a href="https://github.com/arad1367" target="_blank">GitHub</a> |
    <a href="https://arad1367.pythonanywhere.com/" target="_blank">Live demo of my PhD defense</a>
    <br>
    Made with 💖 by Pejman Ebrahimi
</div>
"""

with gr.Blocks(css=CSS, theme="Ajaxon6255/Emerald_Isle") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        # Order matters: these map positionally onto stream_chat's parameters
        # after (message, history): temperature, max_new_tokens, top_p,
        # top_k, penalty.
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.3,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=8192,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                # A repetition penalty must be strictly positive, so 0.0 is
                # not offered as a selectable value.
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Can you explain the Pythagorean theorem?"],
            ["What is the derivative of sin(x)?"],
            ["Solve the integral of e^(2x) dx."],
            ["How does quantum entanglement work?"],
        ],
        cache_examples=False,
    )
    # The footer markup was defined but never added to the page.
    gr.HTML(footer)


if __name__ == "__main__":
    demo.launch()