File size: 4,659 Bytes
6f8934c 9577ec2 6f8934c 9577ec2 6f8934c 69d0c7f 6f8934c 9577ec2 6f8934c 69d0c7f 6f8934c 9577ec2 6f8934c 69d0c7f 6f8934c 69d0c7f 9577ec2 69d0c7f 9577ec2 6f8934c 69d0c7f 6f8934c 9577ec2 6f8934c 69d0c7f 6f8934c 69d0c7f 6f8934c 69d0c7f 6f8934c 9577ec2 6f8934c 69d0c7f 9577ec2 69d0c7f 6f8934c 69d0c7f 9577ec2 69d0c7f 6f8934c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Read the Hugging Face access token from the environment (None when unset).
# NOTE(review): make sure this is forwarded to from_pretrained if the checkpoint is gated.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# HTML header rendered with gr.Markdown at the top of the page.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Mistral 8B Instruct</h1>
</div>
'''
# Footer markup. NOTE(review): nothing in the UI code of this file renders it.
LICENSE = """
<p/>
---
"""
# Markup passed to gr.Chatbot(placeholder=...) for the empty chat area.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Mistral-8B</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
# Custom CSS passed to gr.Blocks(css=...).
# NOTE(review): no component in this file uses elem_id="duplicate-button"; that rule may be dead.
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer and model.
# HF_TOKEN is forwarded so access-restricted (gated) checkpoints can be downloaded;
# when it is None, transformers falls back to any locally cached login.
# torch_dtype="auto" loads the weights in the dtype recorded in the checkpoint's
# config instead of the float32 default, roughly halving GPU memory for bf16 weights.
MODEL_ID = "mistralai/Ministral-8B-Instruct-2410"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    token=HF_TOKEN,
)

# Token ids that stop generation.
# "<|eot_id|>" is a Llama-3 end-of-turn token that is not expected in Mistral's
# vocabulary. For an unknown token, convert_tokens_to_ids() returns the unk id
# (or None), and neither should be treated as an end-of-sequence marker.
# Only add it when it resolves to a real, distinct token.
terminators = [tokenizer.eos_token_id]
_eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
if (
    _eot_id is not None
    and _eot_id != tokenizer.unk_token_id
    and _eot_id not in terminators
):
    terminators.append(_eot_id)
def _build_conversation(message, history, system_prompt) -> list:
    """Turn Gradio chat history plus the new user message into chat-template messages.

    Accepts ``history`` in either Gradio format:
      * ``type="messages"``: a list of ``{"role": ..., "content": ...}`` dicts;
      * legacy tuples: a list of ``[user, assistant]`` pairs.

    The system prompt, if non-blank, is prepended as plain text to the first
    user turn. ``apply_chat_template`` adds the ``[INST]`` markers itself, so
    wrapping it in ``[INST]`` here would produce nested tags.

    Returns a new list of fresh dicts; Gradio's history objects are not mutated.
    """
    conversation = []
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            # Only plain-text user/assistant turns are meaningful to the template.
            if role in ("user", "assistant") and isinstance(content, str):
                conversation.append({"role": role, "content": content})
        else:
            user_msg, assistant_msg = turn[0], turn[1]
            if user_msg is not None:
                conversation.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                conversation.append({"role": "assistant", "content": assistant_msg})

    # The message being answered right now must always be the last user turn.
    conversation.append({"role": "user", "content": message})

    system_prompt = (system_prompt or "").strip()
    if system_prompt:
        # A user turn always exists because `message` was just appended.
        first_user = next(m for m in conversation if m["role"] == "user")
        first_user["content"] = f"{system_prompt}\n\n{first_user['content']}"
    return conversation


@spaces.GPU(duration=120)
def chat_mistral(message: str,
                 history: list,
                 temperature: float,
                 top_p: float,
                 max_new_tokens: int,
                 system_prompt: str) -> Iterator[str]:
    """
    Generate a streaming response using the Mistral-8B model.

    Args:
        message (str): The new user message.
        history (list): The conversation history supplied by ChatInterface, in
            either "messages" (list of role/content dicts) or tuple format.
        temperature (float): Sampling temperature; 0 (or less) selects greedy decoding.
        top_p (float): The top-p (nucleus) sampling parameter (ignored when greedy).
        max_new_tokens (int): The maximum number of new tokens to generate.
        system_prompt (str): Optional instructions prepended to the first user turn.

    Yields:
        str: The response accumulated so far, growing as tokens stream in.
    """
    conversation = _build_conversation(message, history, system_prompt)
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),  # sliders may deliver floats
        eos_token_id=terminators,
    )
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Greedy decoding: omit the sampling knobs so generate() does not warn
        # about temperature/top_p being set while do_sample=False.
        generate_kwargs["do_sample"] = False

    # generate() blocks, so run it in a worker thread and consume the streamer here.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    pieces = []
    for text in streamer:
        pieces.append(text)
        yield "".join(pieces)
    worker.join()
# Gradio UI.
# The Chatbot uses the "messages" history format so it agrees with the
# ChatInterface below, which is created with type='messages'.
chatbot = gr.Chatbot(
    height=450,
    placeholder=PLACEHOLDER,
    label='Gradio ChatInterface',
    type='messages',
)

with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    # Rendered here (no render=False), so the system prompt appears above the chat
    # instead of inside the parameters accordion.
    system_prompt_input = gr.Textbox(
        label="System Prompt",
        placeholder="Enter system instructions for the model...",
        lines=2,
    )
    gr.ChatInterface(
        fn=chat_mistral,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        # ChatInterface passes additional inputs positionally after (message, history).
        # The order must therefore match chat_mistral's signature:
        # temperature, top_p, max_new_tokens, system_prompt.
        additional_inputs=[
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="Top-p", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=4096, label="Max new tokens", render=False),
            system_prompt_input,
        ],
        examples=[
            ['Are you a sentient being?']
        ],
        cache_examples=False,
        type='messages',
    )

if __name__ == "__main__":
    demo.launch()
|