zencorn committed
Commit 6265b0a · 1 Parent(s): b819f19

Update application file
Files changed (2)
  1. Oldapp.py +30 -0
  2. app.py +291 -28
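This commit preserves the previous token-gated multi-page agent launcher as Oldapp.py and rewrites app.py as a Streamlit chat demo for the fine-tuned model at /root/finetune/work_dirs/assistTuner/merged (titled internlm2_5-7b-chat-assistant).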
Oldapp.py ADDED
@@ -0,0 +1,30 @@
+import streamlit as st
+import os
+import runpy
+st.set_page_config(layout="wide", page_title="My Multi-Page App")
+def set_env_variable(key, value):
+    os.environ[key] = value
+def home_page():
+    st.header("欢迎来到首页")
+    # Render the input boxes in password (hidden) mode
+    token = st.text_input("请输入浦语token:", type="password", key="token")
+    weather_token = st.text_input("请输入和风天气token:", type="password", key="weather_token")
+    if st.button("保存并体验agent"):
+        if token and weather_token:
+            set_env_variable("token", token)  # set the environment variable 'token'
+            set_env_variable("weather_token", weather_token)  # set the environment variable 'weather_token'
+            st.session_state.token_entered = True
+            st.rerun()
+        else:
+            st.error("请输入所有token")
+if 'token_entered' not in st.session_state:
+    st.session_state.token_entered = False
+if not st.session_state.token_entered:
+    home_page()
+else:
+    # Dynamically load the sub-pages
+    page = st.sidebar.radio("选择页面", ["天气查询助手", "博客写作助手"])
+    if page == "天气查询助手":
+        runpy.run_path("examples/agent_api_web_demo.py", run_name="__main__")
+    elif page == "博客写作助手":
+        runpy.run_path("examples/multi_agents_api_web_demo.py", run_name="__main__")
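Note: Oldapp.py only collects the two API tokens and stores them in environment variables before handing control to the demo sub-pages. A minimal sketch of how a sub-page could read them back (the helper name read_tokens is illustrative, not part of this commit):

    import os

    def read_tokens():
        # Values saved by the home page via set_env_variable(); None if not set yet.
        return os.environ.get("token"), os.environ.get("weather_token")

    api_token, weather_token = read_tokens()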
app.py CHANGED
@@ -1,30 +1,293 @@
+"""This script refers to the dialogue example of streamlit, the interactive
+generation code of chatglm2 and transformers.
+
+We mainly modified part of the code logic to adapt to the
+generation of our model.
+Please refer to these links below for more information:
+    1. streamlit chat example:
+        https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
+    2. chatglm2:
+        https://github.com/THUDM/ChatGLM2-6B
+    3. transformers:
+        https://github.com/huggingface/transformers
+Please run with the command `streamlit run path/to/web_demo.py
+    --server.address=0.0.0.0 --server.port 7860`.
+Using `python path/to/web_demo.py` may cause unknown problems.
+"""
+# isort: skip_file
+import copy
+import warnings
+from dataclasses import asdict, dataclass
+from typing import Callable, List, Optional
+
 import streamlit as st
-import os
-import runpy
-st.set_page_config(layout="wide", page_title="My Multi-Page App")
-def set_env_variable(key, value):
-    os.environ[key] = value
-def home_page():
-    st.header("欢迎来到首页")
-    # Render the input boxes in password (hidden) mode
-    token = st.text_input("请输入浦语token:", type="password", key="token")
-    weather_token = st.text_input("请输入和风天气token:", type="password", key="weather_token")
-    if st.button("保存并体验agent"):
-        if token and weather_token:
-            set_env_variable("token", token)  # set the environment variable 'token'
-            set_env_variable("weather_token", weather_token)  # set the environment variable 'weather_token'
-            st.session_state.token_entered = True
-            st.rerun()
+import torch
+from torch import nn
+from transformers.generation.utils import (LogitsProcessorList,
+                                           StoppingCriteriaList)
+from transformers.utils import logging
+
+from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
+
+logger = logging.get_logger(__name__)
+'''model_name_or_path="/root/finetune/work_dirs/assistTuner/merged"'''
+
+model_name_or_path = "/root/finetune/work_dirs/assistTuner/merged"
+@dataclass
+class GenerationConfig:
+    # this config is used for chat to provide more diversity
+    max_length: int = 32768
+    top_p: float = 0.8
+    temperature: float = 0.8
+    do_sample: bool = True
+    repetition_penalty: float = 1.005
+
+
+@torch.inference_mode()
+def generate_interactive(
+    model,
+    tokenizer,
+    prompt,
+    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList] = None,
+    stopping_criteria: Optional[StoppingCriteriaList] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
+                                                List[int]]] = None,
+    additional_eos_token_id: Optional[int] = None,
+    **kwargs,
+):
+    inputs = tokenizer([prompt], padding=True, return_tensors='pt')
+    input_length = len(inputs['input_ids'][0])
+    for k, v in inputs.items():
+        inputs[k] = v.cuda()
+    input_ids = inputs['input_ids']
+    _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+    if generation_config is None:
+        generation_config = model.generation_config
+    generation_config = copy.deepcopy(generation_config)
+    model_kwargs = generation_config.update(**kwargs)
+    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
+        generation_config.bos_token_id,
+        generation_config.eos_token_id,
+    )
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if additional_eos_token_id is not None:
+        eos_token_id.append(additional_eos_token_id)
+    has_default_max_length = kwargs.get(
+        'max_length') is None and generation_config.max_length is not None
+    if has_default_max_length and generation_config.max_new_tokens is None:
+        warnings.warn(
+            f"Using 'max_length''s default \
+                ({repr(generation_config.max_length)}) \
+                to control the generation length. "
+            'This behaviour is deprecated and will be removed from the \
+                config in v5 of Transformers -- we'
+            ' recommend using `max_new_tokens` to control the maximum \
+                length of the generation.',
+            UserWarning,
+        )
+    elif generation_config.max_new_tokens is not None:
+        generation_config.max_length = generation_config.max_new_tokens + \
+            input_ids_seq_length
+        if not has_default_max_length:
+            logger.warn(  # pylint: disable=W4902
+                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
+                f"and 'max_length'(={generation_config.max_length}) seem to "
+                "have been set. 'max_new_tokens' will take precedence. "
+                'Please refer to the documentation for more information. '
+                '(https://huggingface.co/docs/transformers/main/'
+                'en/main_classes/text_generation)',
+                UserWarning,
+            )
+
+    if input_ids_seq_length >= generation_config.max_length:
+        input_ids_string = 'input_ids'
+        logger.warning(
+            f'Input length of {input_ids_string} is {input_ids_seq_length}, '
+            f"but 'max_length' is set to {generation_config.max_length}. "
+            'This can lead to unexpected behavior. You should consider'
+            " increasing 'max_new_tokens'.")
+
+    # 2. Set generation parameters if not already defined
+    logits_processor = logits_processor if logits_processor is not None \
+        else LogitsProcessorList()
+    stopping_criteria = stopping_criteria if stopping_criteria is not None \
+        else StoppingCriteriaList()
+
+    logits_processor = model._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_seq_length,
+        encoder_input_ids=input_ids,
+        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        logits_processor=logits_processor,
+    )
+
+    stopping_criteria = model._get_stopping_criteria(
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria)
+    logits_warper = model._get_logits_warper(generation_config)
+
+    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+    scores = None
+    while True:
+        model_inputs = model.prepare_inputs_for_generation(
+            input_ids, **model_kwargs)
+        # forward pass to get next token
+        outputs = model(
+            **model_inputs,
+            return_dict=True,
+            output_attentions=False,
+            output_hidden_states=False,
+        )
+
+        next_token_logits = outputs.logits[:, -1, :]
+
+        # pre-process distribution
+        next_token_scores = logits_processor(input_ids, next_token_logits)
+        next_token_scores = logits_warper(input_ids, next_token_scores)
+
+        # sample
+        probs = nn.functional.softmax(next_token_scores, dim=-1)
+        if generation_config.do_sample:
+            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
         else:
-            st.error("请输入所有token")
-if 'token_entered' not in st.session_state:
-    st.session_state.token_entered = False
-if not st.session_state.token_entered:
-    home_page()
-else:
-    # Dynamically load the sub-pages
-    page = st.sidebar.radio("选择页面", ["天气查询助手", "博客写作助手"])
-    if page == "天气查询助手":
-        runpy.run_path("examples/agent_api_web_demo.py", run_name="__main__")
-    elif page == "博客写作助手":
-        runpy.run_path("examples/multi_agents_api_web_demo.py", run_name="__main__")
+            next_tokens = torch.argmax(probs, dim=-1)
+
+        # update generated ids, model inputs, and length for next step
+        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+        model_kwargs = model._update_model_kwargs_for_generation(
+            outputs, model_kwargs, is_encoder_decoder=False)
+        unfinished_sequences = unfinished_sequences.mul(
+            (min(next_tokens != i for i in eos_token_id)).long())
+
+        output_token_ids = input_ids[0].cpu().tolist()
+        output_token_ids = output_token_ids[input_length:]
+        for each_eos_token_id in eos_token_id:
+            if output_token_ids[-1] == each_eos_token_id:
+                output_token_ids = output_token_ids[:-1]
+        response = tokenizer.decode(output_token_ids)
+
+        yield response
+        # stop when each sentence is finished
+        # or if we exceed the maximum length
+        if unfinished_sequences.max() == 0 or stopping_criteria(
+                input_ids, scores):
+            break
+
+
+def on_btn_click():
+    del st.session_state.messages
+
+
+@st.cache_resource
+def load_model():
+    model = (AutoModelForCausalLM.from_pretrained(
+        model_name_or_path,
+        trust_remote_code=True).to(torch.bfloat16).cuda())
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
+                                              trust_remote_code=True)
+    return model, tokenizer
+
+
+def prepare_generation_config():
+    with st.sidebar:
+        max_length = st.slider('Max Length',
+                               min_value=8,
+                               max_value=32768,
+                               value=32768)
+        top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
+        temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
+        st.button('Clear Chat History', on_click=on_btn_click)
+
+    generation_config = GenerationConfig(max_length=max_length,
+                                         top_p=top_p,
+                                         temperature=temperature)
+
+    return generation_config
+
+
+user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
+robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
+cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
+    <|im_start|>assistant\n'
+
+
+def combine_history(prompt):
+    messages = st.session_state.messages
+    meta_instruction = ('You are a helpful, honest, '
+                        'and harmless AI assistant.')
+    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
+    for message in messages:
+        cur_content = message['content']
+        if message['role'] == 'user':
+            cur_prompt = user_prompt.format(user=cur_content)
+        elif message['role'] == 'robot':
+            cur_prompt = robot_prompt.format(robot=cur_content)
+        else:
+            raise RuntimeError
+        total_prompt += cur_prompt
+    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
+    return total_prompt
+
+
+def main():
+    st.title('internlm2_5-7b-chat-assistant')
+
+    # torch.cuda.empty_cache()
+    print('load model begin.')
+    model, tokenizer = load_model()
+    print('load model end.')
+
+    generation_config = prepare_generation_config()
+
+    # Initialize chat history
+    if 'messages' not in st.session_state:
+        st.session_state.messages = []
+
+    # Display chat messages from history on app rerun
+    for message in st.session_state.messages:
+        with st.chat_message(message['role'], avatar=message.get('avatar')):
+            st.markdown(message['content'])
+
+    # Accept user input
+    if prompt := st.chat_input('What is up?'):
+        # Display user message in chat message container
+
+        with st.chat_message('user', avatar='user'):
+
+            st.markdown(prompt)
+        real_prompt = combine_history(prompt)
+        # Add user message to chat history
+        st.session_state.messages.append({
+            'role': 'user',
+            'content': prompt,
+            'avatar': 'user'
+        })
+
+        with st.chat_message('robot', avatar='assistant'):
+
+            message_placeholder = st.empty()
+            for cur_response in generate_interactive(
+                    model=model,
+                    tokenizer=tokenizer,
+                    prompt=real_prompt,
+                    additional_eos_token_id=92542,
+                    device='cuda:0',
+                    **asdict(generation_config),
+            ):
+                # Display robot response in chat message container
+                message_placeholder.markdown(cur_response + '▌')
+            message_placeholder.markdown(cur_response)
+        # Add robot response to chat history
+        st.session_state.messages.append({
+            'role': 'robot',
+            'content': cur_response,  # pylint: disable=undefined-loop-variable
+            'avatar': 'assistant',
+        })
+        torch.cuda.empty_cache()
+
+
+if __name__ == '__main__':
+    main()
+
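As the new module docstring says, app.py is meant to be launched with Streamlit (e.g. `streamlit run app.py --server.address=0.0.0.0 --server.port 7860`, substituting app.py for the docstring's placeholder path/to/web_demo.py) rather than with plain `python app.py`. For orientation only, a minimal standalone sketch of the ChatML-style prompt that combine_history() assembles from the templates above; the example turns are made up and the indentation quirk in cur_query_prompt is simplified:

    user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
    robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
    cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n'
    system = '<s><|im_start|>system\nYou are a helpful, honest, and harmless AI assistant.<|im_end|>\n'
    # Prior turns (made-up examples) followed by the new user query, as combine_history() does.
    history = user_prompt.format(user='Hello') + robot_prompt.format(robot='Hi, how can I help?')
    print(system + history + cur_query_prompt.format(user='Tell me a joke'))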