import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import gradio as gr
from threading import Thread
# Define constants and configuration
MODEL_LIST = ["mistralai/mathstral-7B-v0.1"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL = os.environ.get("MODEL_ID", MODEL_LIST[0])  # fall back to the default checkpoint when MODEL_ID is unset
PLACEHOLDER = """
<center>
<p>MathΣtral - Your Math advisor</p>
<p>Hi! I'm MisMath, a math advisor. My model is based on mathstral-7B-v0.1, so feel free to ask your questions.</p>
<p>Mathstral 7B is a model specializing in mathematical and scientific tasks, based on Mistral 7B.</p>
<p>mathstral-7B-v0.1 is the first Mathstral model.</p>
<img src="Mistral.png" alt="MathStral Model" style="width:300px;height:200px;">
</center>
"""
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h1 {
    text-align: center;
    font-size: 2em;
    color: #333;
}
"""
TITLE = "<h1><center>MathΣtral - Your Math advisor</center></h1>"
# Device placement is handled by device_map="auto" below.
# Configuration for model quantization
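# NF4 4-bit weights with bfloat16 compute and double quantization roughly
# quarter the weight memory compared to bf16, so the 7B model fits on a single GPU.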
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
# Initialize the tokenizer and the quantized model; HF_TOKEN is forwarded so
# gated checkpoints can be downloaded
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    token=HF_TOKEN,
)
# Define the chat streaming function
@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.8,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2,
):
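    """Generate a streamed reply to `message` given the chat `history`.

    `history` arrives from gr.ChatInterface as a list of (user, assistant)
    pairs; the remaining arguments mirror the sliders defined below.
    """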
    print(f'message: {message}')
    print(f'history: {history}')
    # Prepare the conversation context as plain text: the system prompt first,
    # then the previous turns, then the new user message
    conversation_text = system_prompt + "\n"
    for prompt, answer in history:
        conversation_text += f"User: {prompt}\nAssistant: {answer}\n"
    conversation_text += f"User: {message}\nAssistant:"
    # Tokenize the conversation and move it to the model's device
    input_ids = tokenizer(conversation_text, return_tensors="pt").input_ids.to(model.device)
    # skip_prompt=True makes the streamer yield only newly generated tokens
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,  # greedy decoding when temperature is 0
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,  # wire the repetition-penalty slider into generation
        eos_token_id=tokenizer.eos_token_id,  # stop at the model's own EOS token
        streamer=streamer,
    )
    # Run model.generate in a background thread so tokens can be consumed from
    # the streamer as they are produced (generate itself already disables
    # gradient tracking, and torch.no_grad() would not cross the thread boundary)
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Strip any echoed "Assistant:" prefix before yielding the partial reply
        cleaned_text = buffer.split("Assistant:")[-1].strip()
        yield cleaned_text
# Define the Gradio chatbot component
chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER)
# Define the footer with links
footer = """
<div style="text-align: center; margin-top: 20px;">
<a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
<a href="https://github.com/arad1367" target="_blank">GitHub</a> |
<a href="https://arad1367.pythonanywhere.com/" target="_blank">Live demo of my PhD defense</a>
<br>
Made with 💖 by Pejman Ebrahimi
</div>
"""
# Create and launch the Gradio interface
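# The inputs below are created with render=False so that gr.ChatInterface can
# render them inside the "Parameters" accordion instead of at definition time.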
with gr.Blocks(css=CSS, theme="Ajaxon6255/Emerald_Isle") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful assistant for math questions, complex calculations, and programming. Your name is MisMath.",
                label="System Prompt",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=8192,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=1.0,  # repetition_penalty must be strictly positive; 1.0 disables it
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Can you explain the Pythagorean theorem?"],
            ["What is the derivative of sin(x)?"],
            ["Solve the integral of e^(2x) dx."],
            ["How does quantum entanglement work?"],
        ],
        cache_examples=False,
    )
    gr.HTML(footer)
# Launch the application
if __name__ == "__main__":
    demo.launch()