import gradio as gr
import os
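# `spaces` provides the @spaces.GPU decorator used on Hugging Face ZeroGPU Spaces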
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator
# Read the Hugging Face token from the environment (needed only for gated or private models)
HF_TOKEN = os.environ.get("HF_TOKEN", None)
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Mistral Chat</h1>
</div>
'''
LICENSE = """
<p/>
---
"""
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Mistral Chat 8B</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer and model; device_map="auto" places the weights on the available GPU(s),
# and torch_dtype="auto" loads them in the checkpoint's native precision instead of fp32
tokenizer = AutoTokenizer.from_pretrained("Orenguteng/Llama-3-8B-Lexi-Uncensored")
model = AutoModelForCausalLM.from_pretrained(
    "Orenguteng/Llama-3-8B-Lexi-Uncensored", device_map="auto", torch_dtype="auto"
)
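# Stop generation on either the default EOS token or Llama-3's end-of-turn token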
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
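# Request ZeroGPU hardware for up to 120 seconds per call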
@spaces.GPU(duration=120)
def chat_llama3_8b(message: str,
                   history: list,
                   system_prompt: str,
                   temperature: float,
                   max_new_tokens: int) -> Iterator[str]:
    """
    Generate a streaming response using the Llama-3-8B model.
    Note: the parameter order after `message` and `history` must match the order
    of `additional_inputs` in the ChatInterface below (system prompt, temperature,
    max new tokens), since Gradio passes them positionally.
    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface.
        system_prompt (str): The system prompt to guide the assistant's behavior.
        temperature (float): The sampling temperature; 0 selects greedy decoding.
        max_new_tokens (int): The maximum number of new tokens to generate.
    Yields:
        str: The response generated so far.
    """
    conversation = []
    # Include the system prompt at the beginning if provided
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
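    # Stream decoded tokens as they are generated; skip_prompt=True keeps the
    # echoed input out of the output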
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
    if temperature == 0:
        # Fall back to greedy decoding; drop temperature so transformers does not warn about an unused sampling flag
        generate_kwargs["do_sample"] = False
        generate_kwargs.pop("temperature")
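    # Run generate() on a background thread so the streamer can be consumed
    # here while tokens are still being produced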
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    system_prompt_input = gr.Textbox(
        label="System Prompt",
        placeholder="Enter system instructions for the model...",
        lines=2,
    )
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            system_prompt_input,
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=4096, step=1, value=4096, label="Max new tokens", render=False),
        ],
        examples=[
            ['Are you a sentient being?']
        ],
        cache_examples=False,
    )
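    # Render the footer markdown defined above (LICENSE is otherwise unused)
    gr.Markdown(LICENSE)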
if __name__ == "__main__":
    demo.launch()