import gradio as gr
import os
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# Read the Hugging Face access token from the environment, if one is set
HF_TOKEN = os.environ.get("HF_TOKEN", None)
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Test Model</h1>
</div>
'''
LICENSE = """
<p/>
---
"""
PLACEHOLDER = """
"""
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("Orenguteng/Llama-3-8B-Lexi-Uncensored")
model = AutoModelForCausalLM.from_pretrained("Orenguteng/Llama-3-8B-Lexi-Uncensored", device_map="auto")
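# Stop generation at either the tokenizer's EOS token or Llama-3's end-of-turn token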
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
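# Request a ZeroGPU worker for up to 120 seconds per call when hosted on Hugging Face Spaces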
@spaces.GPU(duration=120)
def chat_llama3_8b(message: str,
history: list,
temperature: float,
max_new_tokens: int,
top_p: float,
system_prompt: str
) -> str:
"""
Generate a streaming response using the llama3-8b model.
Args:
message (str): The input message.
history (list): The conversation history used by ChatInterface.
temperature (float): The temperature for generating the response.
max_new_tokens (int): The maximum number of new tokens to generate.
top_p (float): The top_p value for nucleus sampling.
system_prompt (str): The system prompt to guide the conversation.
    Yields:
        str: The response generated so far, accumulated across streamed chunks.
"""
    conversation = [{"role": "system", "content": system_prompt}]
    # The interface uses type='messages', so history arrives as OpenAI-style role/content dicts
    for msg in history:
        conversation.append({"role": msg["role"], "content": msg["content"]})
    conversation.append({"role": "user", "content": message})
    # add_generation_prompt appends the assistant header so the model answers instead of echoing the role tag
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
    # The prompt is a single unpadded sequence, so every position should be attended to
    attention_mask = torch.ones_like(input_ids)
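    # Stream decoded tokens as they are generated, hiding the prompt and special tokens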
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
        input_ids=input_ids,
attention_mask=attention_mask,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
eos_token_id=terminators,
pad_token_id=tokenizer.eos_token_id,
)
    # Fall back to greedy decoding (do_sample=False) when the temperature is set to 0, which would otherwise crash sampling
if temperature == 0:
generate_kwargs['do_sample'] = False
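    # Run generation in a background thread so tokens can be yielded to the UI as they arrive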
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
first_chunk = True
    for text in streamer:
        # Defensive cleanup: strip a leading "assistant" role tag if the template still emits one
        if first_chunk and text.startswith("assistant"):
            text = text[len("assistant"):].lstrip(": \n")
        first_chunk = False
        outputs.append(text)
        yield "".join(outputs)
# Build the Gradio UI: a custom Chatbot rendered inside a ChatInterface
chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface', type='messages')
with gr.Blocks(fill_height=True, css=css) as aida:
gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        type="messages",
        fill_height=True,
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
additional_inputs=[
gr.Slider(minimum=0,
maximum=1,
step=0.1,
value=0.8,
label="Temperature",
render=False),
gr.Slider(minimum=128,
maximum=4096,
step=1,
value=4096,
label="Max new tokens",
render=False ),
gr.Slider(minimum=0,
maximum=1,
step=0.1,
value=0.9,
label="Top_p",
render=False),
gr.Textbox(lines=2,
placeholder="Enter system prompt here...",
label="System Prompt",
render=False),
],
examples=[
['Who Are you?']
],
cache_examples=False,
)
if __name__ == "__main__":
aida.launch(share=True)