import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import os
os.system('huggingface-cli download matteogeniaccio/phi-4 --local-dir ./phi-4 --include "phi-4/*"')
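# Equivalent in-process alternative (a sketch; huggingface_hub is already a
# dependency of transformers):
# from huggingface_hub import snapshot_download
# snapshot_download("matteogeniaccio/phi-4", local_dir="./phi-4", allow_patterns="phi-4/*")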
# Load the phi-4 model and tokenizer
torch.random.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
    "./phi-4/phi-4",          # local model path
    device_map="cuda",        # run on the GPU
    torch_dtype="auto",       # pick the dtype automatically
    trust_remote_code=True,   # allow custom model code from the checkpoint
)
tokenizer = AutoTokenizer.from_pretrained("./phi-4/phi-4")
# Set up the text-generation pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
# Response callback for the chat interface; @spaces.GPU requests a GPU
# for the duration of each call on ZeroGPU Spaces
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Build the message list: system prompt, then the chat history, then the new user turn
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    # Flatten the messages into a plain prompt string (for text-generation)
    input_text = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in messages
    )
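    # Note: phi-4 ships a chat template, so a more faithful prompt could be
    # built with (an alternative sketch, not what runs below):
    # input_text = tokenizer.apply_chat_template(
    #     messages, tokenize=False, add_generation_prompt=True
    # )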
    # Generate the response
    generation_args = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": temperature > 0,  # sample whenever temperature is positive
        "return_full_text": False,     # return only the newly generated text
    }
    output = pipe(input_text, **generation_args)
    response = output[0]["generated_text"]
    # Stream the reply: yield the accumulated text so far, so the chat
    # window shows the message growing rather than single characters
    partial = ""
    for ch in response:
        partial += ch
        yield partial
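    # For true token-level streaming (rather than re-chunking the finished
    # string), transformers' TextIteratorStreamer could be wired in; a sketch:
    # from threading import Thread
    # from transformers import TextIteratorStreamer
    # streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    # Thread(target=pipe, args=(input_text,),
    #        kwargs={**generation_args, "streamer": streamer}).start()
    # partial = ""
    # for chunk in streamer:
    #     partial += chunk
    #     yield partial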
# Gradio chat interface with sampling controls as additional inputs
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
if __name__ == "__main__":
    demo.launch()