from threading import Thread

import torch
import spaces  # ZeroGPU helper; provides the @spaces.GPU decorator used below
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = 4096
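
# predict() below hand-rolls the prompt in ChatML-style turn markup
# (<|im_start|>role ... <|im_end|>) and streams tokens back to the UI
# as they are produced.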
@spaces.GPU  # request a GPU slice per call when running on a ZeroGPU Space
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Build the prompt: system turn, then the running history,
    # then the new user turn with an open assistant turn for the model to fill.
    stop_tokens = ["<|endoftext|>", "<|im_end|>"]
    instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
    for user, assistant in history:
        instruction += f'<|im_start|>user\n{user}\n<|im_end|>\n<|im_start|>assistant\n{assistant}\n<|im_end|>\n'
    instruction += f'<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n'

    # Stream tokens as they are generated; skip_prompt drops the echoed input.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer(instruction, return_tensors="pt", truncation=True, max_length=CONTEXT_LENGTH)
    enc = enc.to(model.device)  # keep inputs on the same device as the model

    generate_kwargs = dict(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )

    # generate() blocks, so run it in a worker thread and consume the streamer here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for new_token in streamer:
        if new_token in stop_tokens:
            break  # stop generation but don't emit the stop token itself
        outputs.append(new_token)
        yield "".join(outputs).replace("<|im_end|>", "")  # strip any leftover stop markers
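
# A minimal alternative sketch: recent transformers tokenizers can render the
# checkpoint's own chat template instead of the hand-built string above
# (assuming this checkpoint ships one; untested here):
#
#   messages = [{"role": "system", "content": system_prompt},
#               {"role": "user", "content": message}]
#   prompt = tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True)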

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
model.to("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is present
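
# On a dedicated GPU one would usually load in half precision instead,
# e.g. (a sketch; device_map="auto" additionally requires the accelerate package):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID, torch_dtype=torch.float16, device_map="auto")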

gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Textbox(
            "You are a helpful assistant. Format responses clearly using natural Markdown formatting where appropriate.",
            label="System prompt",
        ),
        gr.Slider(0, 1, 0.6, label="Temperature"),
        gr.Slider(0, 4096, 512, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    css=".message { white-space: pre-wrap; }",  # preserve newlines in chat bubbles
).queue().launch()
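
# .queue() enables request queuing, so generation requests are processed one at
# a time by default rather than racing on the single shared model; launch()
# starts the Gradio server (on Spaces, the platform supplies host and port).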