llama-cpp-agent / app.py
pabloce's picture
Update app.py
e2f7d5c verified
raw
history blame
2.8 kB
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
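"""
This app runs the model locally with `transformers`; the `huggingface_hub` Inference API client below is commented out and not used.
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""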
# client = InferenceClient("cognitivecomputations/dolphin-2.8-mistral-7b-v02")
# Hub repository that provides both the tokenizer and the model weights.
_MODEL_ID = "cognitivecomputations/dolphin-2.8-mistral-7b-v02"
# End-of-turn marker: closes every turn in the prompt and is the stop string
# watched for while streaming.
_IM_END = "<|im_end|>"


def _build_chatml_prompt(system_message, history, message):
    """Return the full prompt for the conversation plus the new user message.

    Args:
        system_message: Text placed in the system turn. A "." is appended
            unconditionally, exactly as the original prompt did.
        history: Previous turns; each item is indexed as
            ``(user_text, assistant_text)``.
        message: The new user message to answer.

    Returns:
        One string in which every completed turn is closed with the
        end-of-turn marker and the final assistant turn is left open so the
        model continues from there.
    """
    segments = [f"<|im_start|>system\n{system_message}.{_IM_END}"]
    for turn in history:
        segments.append(
            "\n<|im_start|>user\n" + turn[0]
            + _IM_END + "\n<|im_start|>assistant\n" + turn[1]
            # Close the finished assistant turn; without this marker the
            # previous reply runs straight into the next user turn.
            + _IM_END
        )
    # The new message ends with an open assistant header and no end marker.
    segments.append(
        "\n<|im_start|>user\n" + message
        + _IM_END + "\n<|im_start|>assistant\n"
    )
    return "".join(segments)


@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream the model's reply to ``message`` as it is generated.

    The tokenizer and the 4-bit model are loaded inside the call, the prompt
    is built from the system message and chat history, and generation runs
    on a background thread while this generator consumes the streamer.

    Args:
        message: The new user message.
        history: Previous ``(user, assistant)`` turns of the conversation.
        system_message: System instruction placed at the top of the prompt.
        max_tokens: Passed to ``generate`` as ``max_new_tokens``.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.

    Yields:
        The reply text accumulated so far. If the end-of-turn marker shows
        up in the streamed text, the text before it is yielded once more
        (trimmed) and streaming stops.
    """
    torch.set_default_device("cuda")

    tokenizer = AutoTokenizer.from_pretrained(
        _MODEL_ID,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_ID,
        torch_dtype="auto",
        load_in_4bit=True,
        trust_remote_code=True,
    )

    prompt = _build_chatml_prompt(system_message, history, message)
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # The tokenizer output is mapping-like, so its entries (e.g. input ids)
    # are merged with the sampling options into one kwargs dict.
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=50,
        temperature=temperature,
        num_beams=1,
    )

    # generate() blocks until it is finished, so it runs on a worker thread
    # while this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        reply, marker, _ = partial_message.partition(_IM_END)
        if marker:
            # Emit the text that preceded the marker before stopping, so a
            # chunk carrying both reply text and the marker is not dropped.
            yield reply
            break
        yield partial_message
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI around `respond`. The extra inputs are passed to `respond` after
# (message, history) in the order listed here: system message, max new
# tokens, temperature, top-p.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    # Soft theme with custom hues and font; the dark-mode block background
    # is overridden via .set().
    theme=gr.themes.Soft(
        primary_hue="green",
        secondary_hue="indigo",
        neutral_hue="zinc",
        font=[
            gr.themes.GoogleFont("Exo 2"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ],
    ).set(block_background_fill_dark="*neutral_800"),
)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()