import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch
# Set an environment variable
HF_TOKEN = os.environ.get("HF_TOKEN", None)
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Mistral 8B Instruct</h1>
</div>
'''
LICENSE = """
<p/>
---
"""
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Mistral-8B</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("mistralai/Ministral-8B-Instruct-2410")
# torch_dtype="auto" loads the checkpoint's native bf16 weights instead of defaulting to fp32
model = AutoModelForCausalLM.from_pretrained("mistralai/Ministral-8B-Instruct-2410", device_map="auto", torch_dtype="auto")
# Ensure we have a pad token
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
# The Mistral tokenizer has no separate end-of-turn token (<|eot_id|> is a Llama-3 token);
# generation simply stops at EOS
terminators = [tokenizer.eos_token_id]
@spaces.GPU(duration=120)
def chat_mistral(message: str,
history: list,
temperature: float,
top_p: float,
max_new_tokens: int,
system_prompt: str) -> str:
"""
    Generate a streaming response using the Ministral-8B model.
Args:
message (str): The input message.
history (list): The conversation history used by ChatInterface.
temperature (float): The temperature for generating the response.
top_p (float): The top-p (nucleus) sampling parameter.
max_new_tokens (int): The maximum number of new tokens to generate.
system_prompt (str): The system prompt to guide the assistant's behavior.
    Yields:
        str: The generated response accumulated so far.
"""
conversation = []
    # apply_chat_template already wraps each user turn in [INST] ... [/INST],
    # so the system prompt is prepended as plain text to the first user message
    # instead of being wrapped in extra [INST] tags.
    formatted_prompt = f"{system_prompt}\n\n" if system_prompt else ""
    if history:
        # Replay the history, folding the system prompt into the first user turn
        conversation.append({"role": "user", "content": f"{formatted_prompt}{history[0][0]}"})
        conversation.append({"role": "assistant", "content": history[0][1]})
        for user, assistant in history[1:]:
            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
        # Append the current user message as the final turn
        conversation.append({"role": "user", "content": message})
    else:
        # First message in a new conversation
        conversation.append({"role": "user", "content": f"{formatted_prompt}{message}"})
    # Apply the chat template and move the prompt to the model's device.
    # With a single sequence there is no padding, so a full attention mask is used;
    # masking on pad_token_id would hide the EOS tokens that separate earlier turns,
    # because pad_token_id was set to eos_token_id above.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
    attention_mask = torch.ones_like(input_ids)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=input_ids,
        attention_mask=attention_mask,  # explicit mask avoids the missing-attention-mask warning from generate()
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,  # pad token falls back to EOS (set above)
eos_token_id=terminators,
)
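    # Sampling with temperature 0 is undefined; fall back to greedy decoding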
if temperature == 0:
generate_kwargs['do_sample'] = False
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
with gr.Blocks(fill_height=True, css=css) as demo:
gr.Markdown(DESCRIPTION)
system_prompt_input = gr.Textbox(
label="System Prompt",
placeholder="Enter system instructions for the model...",
lines=2
)
gr.ChatInterface(
fn=chat_mistral,
chatbot=chatbot,
fill_height=True,
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            # Order must match chat_mistral's parameters after (message, history):
            # temperature, top_p, max_new_tokens, system_prompt
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="Top-p", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=4096, label="Max new tokens", render=False),
            system_prompt_input,
        ],
examples=[
['How to setup a human base on Mars? Give short answer.'],
['Explain theory of relativity to me like I’m 8 years old.'],
['What is 9,000 * 9,000?'],
['Write a pun-filled happy birthday message to my friend Alex.'],
['Justify why a penguin might make a good king of the jungle.']
],
cache_examples=False
)
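    # Render the (currently empty) LICENSE footer defined above; assumes it was
    # meant to appear below the chat, as in similar demo Spaces.
    gr.Markdown(LICENSE)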
if __name__ == "__main__":
demo.launch()