mattcracker committed
Commit ba12288 · verified · 1 Parent(s): 3c0bd1f

Update app.py

Files changed (1)
  1. app.py +132 -46
app.py CHANGED
@@ -1,84 +1,170 @@
 
  from threading import Thread
  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
  import spaces
- tokenizer = AutoTokenizer.from_pretrained("agentica-org/DeepScaleR-1.5B-Preview")
- model = AutoModelForCausalLM.from_pretrained("agentica-org/DeepScaleR-1.5B-Preview", device_map='auto')
-

  def preprocess_messages(history):
-     messages = []
-
-     for idx, (user_msg, model_msg) in enumerate(history):
-         if idx == len(history) - 1 and not messages:
-             messages.append({"role": "user", "content": user_msg})
-             break
          if user_msg:
-             messages.append({"role": "user", "content": user_msg})
-         if model_msg:
-             messages.append({"role": "assistant", "content": messages})
-
-     return messages
-
-
- @spaces.GPU()
  def predict(history, max_length, top_p, temperature):
-     messages = preprocess_messages(history)
-     model_inputs = tokenizer.apply_chat_template(
-         messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
-     ).to(model.device)
-     streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = {
-         "input_ids": model_inputs["input_ids"],
-         "attention_mask": model_inputs["attention_mask"],
-         "streamer": streamer,
          "max_new_tokens": max_length,
          "do_sample": True,
          "top_p": top_p,
          "temperature": temperature,
          "repetition_penalty": 1.2,
      }

-     generate_kwargs['eos_token_id'] = tokenizer.encode("<|user|>")
-
      t = Thread(target=model.generate, kwargs=generate_kwargs)
      t.start()
      for new_token in streamer:
-         if new_token:
-             history[-1][1] += new_token
          yield history

-
  def main():
      with gr.Blocks() as demo:
-         gr.HTML("""<h1 align="center">GLM-Edge-Chat Gradio Demo</h1>""")

-         with gr.Row():
-             with gr.Column(scale=3):
-                 chatbot = gr.Chatbot()

          with gr.Row():
              with gr.Column(scale=2):
-                 user_input = gr.Textbox(show_label=True, placeholder="Input...", label="User Input")
                  submitBtn = gr.Button("Submit")
                  emptyBtn = gr.Button("Clear History")
              with gr.Column(scale=1):
-                 max_length = gr.Slider(0, 8192, value=4096, step=1.0, label="Maximum length", interactive=True)
-                 top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
-                 temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
-
-         # Define functions for button actions
          def user(query, history):
              return "", history + [[query, ""]]

-         submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
-             predict, [chatbot, max_length, top_p, temperature], chatbot
          )
-         emptyBtn.click(lambda: (None, None), None, [chatbot], queue=False)

-     demo.queue()
-     demo.launch()


  if __name__ == "__main__":
-     main()
 
+ # app.py
  from threading import Thread
  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+ import torch
  import spaces

+ # ---------------------------------------------
+ # 1. Load the model and tokenizer
+ # ---------------------------------------------
+ # If your model needs special settings such as acceleration or quantization,
+ # add the corresponding arguments to from_pretrained(),
+ # e.g. device_map='auto' or trust_remote_code=True
+ model_name = "agentica-org/DeepScaleR-1.5B-Preview"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
+
+ # Apply .half()/.float()/.quantize() etc. as needed,
+ # for example:
+ # model.half()
+ # or
+ # model = model.quantize(4/8)  # if your model and environment support it
+
+ # ---------------------------------------------
+ # 2. Conversation-history handling
+ # ---------------------------------------------
  def preprocess_messages(history):
+     """
+     Concatenate all user and assistant messages into a single text prompt.
+     This is only the simplest possible form:
+         User: ...
+         Assistant: ...
+     with a final "Assistant: " appended to prompt the model to continue.
+     You can change this to whatever chat template you need.
+     """
+     prompt = ""
+     for user_msg, assistant_msg in history:
          if user_msg:
+             prompt += f"User: {user_msg}\n"
+         if assistant_msg:
+             prompt += f"Assistant: {assistant_msg}\n"
+
+     # Append "Assistant: " so the model continues the reply from there
+     prompt += "Assistant: "
+     return prompt
+
+ # ---------------------------------------------
+ # 3. Prediction function
+ # ---------------------------------------------
+ @spaces.GPU()  # GPU decorator for Hugging Face Spaces
  def predict(history, max_length, top_p, temperature):
+     """
+     Takes the conversation history and a few hyperparameters, and streams
+     the generated result: each new token is yielded back to Gradio,
+     which updates the UI.
+     """
+     prompt = preprocess_messages(history)
+
+     # Assemble the model inputs
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs["input_ids"].to(model.device)
+
+     # Use TextIteratorStreamer for streaming output
+     streamer = TextIteratorStreamer(
+         tokenizer=tokenizer,
+         timeout=60,
+         skip_prompt=True,
+         skip_special_tokens=True
+     )
+
      generate_kwargs = {
+         "input_ids": input_ids,
          "max_new_tokens": max_length,
          "do_sample": True,
          "top_p": top_p,
          "temperature": temperature,
          "repetition_penalty": 1.2,
+         "streamer": streamer,
+         # Add custom special tokens or other parameters here if needed
+         # "eos_token_id": ...
      }

+     # Launch generate() in a thread; the main thread reads the streamed output
      t = Thread(target=model.generate, kwargs=generate_kwargs)
      t.start()
+
+     # history[-1][1] holds the latest assistant reply, so keep appending to it
+     partial_output = ""
      for new_token in streamer:
+         partial_output += new_token
+         history[-1][1] = partial_output
          yield history

+ # ---------------------------------------------
+ # 4. Build the Gradio UI
+ # ---------------------------------------------
  def main():
      with gr.Blocks() as demo:
+         gr.HTML("<h1 align='center'>DeepScaleR-1.5B-Preview Chat Demo</h1>")

+         # Chat window
+         chatbot = gr.Chatbot()

          with gr.Row():
              with gr.Column(scale=2):
+                 user_input = gr.Textbox(
+                     show_label=True,
+                     placeholder="Type your question...",
+                     label="User Input"
+                 )
                  submitBtn = gr.Button("Submit")
                  emptyBtn = gr.Button("Clear History")
              with gr.Column(scale=1):
+                 max_length = gr.Slider(
+                     minimum=0,
+                     maximum=2048,  # adjust to the model's capabilities
+                     value=512,
+                     step=1,
+                     label="Max New Tokens",
+                     interactive=True
+                 )
+                 top_p = gr.Slider(
+                     minimum=0,
+                     maximum=1,
+                     value=0.8,
+                     step=0.01,
+                     label="Top P",
+                     interactive=True
+                 )
+                 temperature = gr.Slider(
+                     minimum=0.01,
+                     maximum=2.0,
+                     value=0.7,
+                     step=0.01,
+                     label="Temperature",
+                     interactive=True
+                 )
+
+         # Insert the user's input into the chatbot history
          def user(query, history):
              return "", history + [[query, ""]]

+         # Submit:
+         # 1) user()    -> append a (user input, "") record to the history
+         # 2) predict() -> generate based on the updated history
+         submitBtn.click(
+             fn=user,
+             inputs=[user_input, chatbot],
+             outputs=[user_input, chatbot],
+             queue=False
+         ).then(
+             fn=predict,
+             inputs=[chatbot, max_length, top_p, temperature],
+             outputs=chatbot
          )

+         # Clear: empty the chat history and the input box
+         def clear_history():
+             return [], ""
+         emptyBtn.click(
+             fn=clear_history,
+             inputs=[],
+             outputs=[chatbot, user_input],
+             queue=False
+         )

+     # Optional: let Gradio schedule queued requests automatically
+     demo.queue()
+     demo.launch()

  if __name__ == "__main__":
+     main()
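
For reference, the core technique this commit adopts is running model.generate() in a worker thread while the main thread consumes a TextIteratorStreamer. Below is a minimal standalone sketch of that pattern, not part of the commit: the model name "gpt2" and the prompt are placeholder assumptions chosen so the snippet runs anywhere, and no Gradio wiring is involved. Note also that for history = [["What is 2+2?", "4"], ["And 3+3?", ""]], the new preprocess_messages() would produce the prompt "User: What is 2+2?\nAssistant: 4\nUser: And 3+3?\nAssistant: ".

    # Minimal sketch (assumptions: "gpt2" as a stand-in model, fixed prompt).
    # generate() blocks, so it runs in a worker thread while the main thread
    # iterates the streamer and receives decoded text chunks as they arrive.
    from threading import Thread
    from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tok("User: Hello!\nAssistant: ", return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    thread = Thread(target=lm.generate, kwargs={
        "input_ids": inputs["input_ids"],
        "max_new_tokens": 32,
        "streamer": streamer,
    })
    thread.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)  # arrives incrementally, as in predict()
    thread.join()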