# llama-cpp-agent / app.py
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# client = InferenceClient("cognitivecomputations/dolphin-2.8-mistral-7b-v02")

# Unused helper that renders a Llama-2/Mistral-style [INST] prompt; the
# respond() handler below uses ChatML (<|im_start|>/<|im_end|>) instead.
def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
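
# Example (illustrative) of format_prompt output:
#   format_prompt("Hi", [("Hello", "Hey!")])
#   -> "<s>[INST] Hello [/INST] Hey!</s> [INST] Hi [/INST]"
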
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    torch.set_default_device("cuda")
    # Loading inside the handler keeps the GPU allocation scoped to the
    # @spaces.GPU call (ZeroGPU), at the cost of reloading on every request.
    tokenizer = AutoTokenizer.from_pretrained(
        "Weyaxi/Einstein-v6.1-Llama3-8B",
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "Weyaxi/Einstein-v6.1-Llama3-8B",
        torch_dtype="auto",
        load_in_4bit=True,  # 4-bit quantization; requires bitsandbytes
        trust_remote_code=True,
    )

    # Append the new user turn with an empty assistant slot, then render the
    # whole conversation in ChatML. The system message comes from the UI
    # textbox exposed via additional_inputs.
    history_transformer_format = history + [[message, ""]]
    messages = f"<|im_start|>system\n{system_message}<|im_end|>"
    for user_turn, bot_turn in history_transformer_format:
        messages += f"\n<|im_start|>user\n{user_turn}<|im_end|>\n<|im_start|>assistant\n{bot_turn}"
        if bot_turn:
            # Close completed assistant turns; the final, empty turn stays
            # open so the model continues from "<|im_start|>assistant".
            messages += "<|im_end|>"

    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,  # BatchEncoding is a mapping: unpacks input_ids and attention_mask
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=50,
        temperature=temperature,
        num_beams=1,
    )

    # Generate on a background thread so tokens can be yielded as the
    # streamer produces them.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if "<|im_end|>" in partial_message:
            # Stop at the end-of-turn marker and emit the text without it.
            yield partial_message.replace("<|im_end|>", "")
            break
        yield partial_message
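
# Illustrative ChatML prompt for system_message="You are a friendly Chatbot.",
# one past exchange ("Hi" -> "Hello!"), and new message "How are you?":
#
#   <|im_start|>system
#   You are a friendly Chatbot.<|im_end|>
#   <|im_start|>user
#   Hi<|im_end|>
#   <|im_start|>assistant
#   Hello!<|im_end|>
#   <|im_start|>user
#   How are you?<|im_end|>
#   <|im_start|>assistant
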
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    theme=gr.themes.Soft(
        primary_hue="green",
        secondary_hue="indigo",
        neutral_hue="zinc",
        font=[gr.themes.GoogleFont("Exo 2"), "ui-sans-serif", "system-ui", "sans-serif"],
    ).set(block_background_fill_dark="*neutral_800"),
)
if __name__ == "__main__":
    demo.launch()
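
# To run outside of Spaces you would typically need gradio, torch, transformers,
# bitsandbytes (for load_in_4bit), and accelerate; the `spaces` import and the
# @spaces.GPU decorator are specific to Hugging Face ZeroGPU Spaces.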