studyinglover committed on
Commit b9f52a3 · 1 Parent(s): d90f8fe
Files changed (1)
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
+ import gradio as gr
+ import torch
+ import time
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load the tokenizer and model
+ tokenizer_path = "studyinglover/IntelliKernel-0.03b-sft"
+ model_path = "studyinglover/IntelliKernel-0.03b-sft"
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+
+ # Generate a reply for the latest user message
+ def chat_with_model(history, user_input, top_k, temperature):
+     # Append the user input to the conversation history
+     history.append({"role": "user", "content": user_input})
+
+     # Build the new prompt (note: this slice truncates by characters, not tokens)
+     new_prompt = tokenizer.apply_chat_template(
+         history, tokenize=False, add_generation_prompt=True
+     )[-(model.config.max_seq_len - 1) :]
+
+     # Encode the input and move it to the device
+     x = tokenizer(new_prompt, return_tensors="pt").input_ids.to(device)
+
+     # Generate a reply and time it; the positional eos_token_id and
+     # `stream=True` follow the model's custom generate() loaded via
+     # trust_remote_code, which yields token chunks as they are produced
+     output_text = ""
+     start_time = time.time()
+     with torch.inference_mode():
+         _output = model.generate(
+             x,
+             tokenizer.eos_token_id,
+             max_new_tokens=512,
+             top_k=top_k,
+             temperature=temperature,
+             stream=True,
+         )
+
+         for i in _output:
+             output = tokenizer.decode(i[0].tolist())
+             output_text += output
+
+     end_time = time.time()
+     elapsed_time = end_time - start_time
+     num_tokens = len(tokenizer.encode(output_text))
+     token_speed = num_tokens / elapsed_time if elapsed_time > 0 else 0
+
+     # Report the token count and generation speed for the latest turn
+     token_info = (
+         f"Token count: {num_tokens}\nToken output speed: {token_speed:.2f} tokens/sec"
+     )
+
+     # Append the model's reply to the conversation history
+     history.append({"role": "assistant", "content": output_text.strip()})
+
+     # Return the updated history, clear the input box, and show the token info
+     return history, "", token_info
+
+
+ # Build the chatbot interface with Gradio
+ with gr.Blocks() as iface:
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Left column: sampling controls
+             top_k_slider = gr.Slider(0, 100, value=50, step=1, label="Top-k")
+             temp_slider = gr.Slider(0.1, 1.5, value=1.0, step=0.1, label="Temperature")
+             token_info_box = gr.Markdown(
+                 "Token count: \nToken output speed: "
+             )  # Box that displays token statistics
+         with gr.Column(scale=3):
+             # Right column: chat area
+             gr.Markdown(
+                 "# Chat with AI\nA simple chat interface: type a message and the model will generate a reply."
+             )
+             chatbot = gr.Chatbot(type="messages")  # Store the conversation in "messages" format
+             msg = gr.Textbox(label="Your Message")  # User input box
+             with gr.Row():
+                 send_btn = gr.Button("Send Message")  # Send-message button
+                 clear = gr.Button("Clear Chat")  # Clear-chat button
+
+     # Wire up the interactions
+     send_btn.click(
+         chat_with_model,
+         [chatbot, msg, top_k_slider, temp_slider],
+         [chatbot, msg, token_info_box],
+     )  # Send on button click
+     msg.submit(
+         chat_with_model,
+         [chatbot, msg, top_k_slider, temp_slider],
+         [chatbot, msg, token_info_box],
+     )  # Send on Enter
+     clear.click(lambda: None, None, chatbot, queue=False)  # Clear the chat history
+
+ iface.launch()
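
As a usage note, here is a minimal sketch of exercising `chat_with_model` outside the Gradio UI, assuming the model weights download successfully and its custom `generate()` behaves as the code above expects. The empty history list and the sampling values are illustrative, not part of the commit.

history = []
history, _, token_info = chat_with_model(
    history, "Hello!", top_k=50, temperature=1.0
)
print(history[-1]["content"])  # the assistant's reply
print(token_info)              # token count and tokens/sec for this turn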