shuaikang committed
Commit 96a12df · verified
Parent(s): 90bdb08

Update app.py

Files changed (1)
  1. app.py +113 -106
app.py CHANGED
@@ -1,110 +1,117 @@
-
- from transformers import AutoModel, AutoTokenizer
  import gradio as gr
- import mdtex2html
- # from utils import load_model_on_gpus
-
- tokenizer = AutoTokenizer.from_pretrained("sethuiyer/Medichat-Llama3-8B", trust_remote_code=True)
- # model = AutoModel.from_pretrained("sethuiyer/Medichat-Llama3-8B", trust_remote_code=True).cuda()
- model = AutoModel.from_pretrained("sethuiyer/Medichat-Llama3-8B", trust_remote_code=True)
- # Multi-GPU support: use the two lines below instead of the line above, and set num_gpus to your actual number of GPUs
- # from utils import load_model_on_gpus
- # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
- model = model.eval()
-
- """Override Chatbot.postprocess"""
-
-
- def postprocess(self, y):
-     if y is None:
-         return []
-     for i, (message, response) in enumerate(y):
-         y[i] = (
-             None if message is None else mdtex2html.convert(message),
-             None if response is None else mdtex2html.convert(response),
-         )
-     return y
-
-
- gr.Chatbot.postprocess = postprocess
-
-
- def parse_text(text):
-     """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
-     lines = text.split("\n")
-     lines = [line for line in lines if line != ""]
-     count = 0
-     for i, line in enumerate(lines):
-         if "```" in line:
-             count += 1
-             items = line.split('`')
-             if count % 2 == 1:
-                 lines[i] = f'<pre><code class="language-{items[-1]}">'
-             else:
-                 lines[i] = f'<br></code></pre>'
-         else:
-             if i > 0:
-                 if count % 2 == 1:
-                     line = line.replace("`", "\`")
-                     line = line.replace("<", "&lt;")
-                     line = line.replace(">", "&gt;")
-                     line = line.replace(" ", "&nbsp;")
-                     line = line.replace("*", "&ast;")
-                     line = line.replace("_", "&lowbar;")
-                     line = line.replace("-", "&#45;")
-                     line = line.replace(".", "&#46;")
-                     line = line.replace("!", "&#33;")
-                     line = line.replace("(", "&#40;")
-                     line = line.replace(")", "&#41;")
-                     line = line.replace("$", "&#36;")
-                 lines[i] = "<br>" + line
-     text = "".join(lines)
-     return text
-
-
- def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
-     chatbot.append((parse_text(input), ""))
-     for response, history, past_key_values in model.stream_chat(tokenizer, input, history,
-                                                                 past_key_values=past_key_values,
-                                                                 return_past_key_values=True,
-                                                                 max_length=max_length, top_p=top_p,
-                                                                 temperature=temperature):
-         chatbot[-1] = (parse_text(input), parse_text(response))
-
-         yield chatbot, history, past_key_values
-
-
- def reset_user_input():
-     return gr.update(value='')
-
-
- def reset_state():
-     return [], [], None


  with gr.Blocks() as demo:
-     gr.HTML("""<h1 align="center">ChatGLM2-6B</h1>""")
-
-     chatbot = gr.Chatbot()
-     with gr.Row():
-         with gr.Column(scale=4):
-             with gr.Column(scale=12):
-                 user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
-                     container=False)
-             with gr.Column(min_width=32, scale=1):
-                 submitBtn = gr.Button("Submit", variant="primary")
-         with gr.Column(scale=1):
-             emptyBtn = gr.Button("Clear History")
-             max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
-             top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
-             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
-
-     history = gr.State([])
-     past_key_values = gr.State(None)
-
-     submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
-                     [chatbot, history, past_key_values], show_progress=True)
-     submitBtn.click(reset_user_input, [], [user_input])
-
-     emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
-
- demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
+ import torch
  import gradio as gr
+ from threading import Thread
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ device = "cuda"  # the device to load the model onto
+ # device = "cpu"  # the device to load the model onto
+
+
+ bot_avatar = "shuaikang/dl_logo_rect.png"  # path to the chatbot avatar image
+ user_avatar = "shuaikang/user_avatar.jpg"  # path to the user avatar image
+ # model_path = "sethuiyer/Medichat-Llama3-8B"  # path to the downloaded model
+ # model_path = "johnsnowlabs/JSL-MedMX-7X"
+ model_path = "aaditya/Llama3-OpenBioLLM-8B"
+
+ # Global conversation history. Llama 3 supports a system prompt, so one is set by default.
+ llama3_chat_history = [
+     {"role": "system", "content": "You are a helpful assistant trained by MetaAI! But you are running with DataLearnerAI Code."}
+ ]
+
+ # Initialize the globals that will hold the loaded model
+ tokenizer = None
+ streamer = None
+ model = None
+ terminators = None
+
+
+ def init_model():
+     """Initialize the model by loading it from local files."""
+     global tokenizer, model, streamer, terminators
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_path, local_files_only=True)
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_path,
+         torch_dtype=torch.float16,
+         device_map=device,
+         trust_remote_code=True
+     )
+
+     # Llama 3 marks the end of a turn with <|eot_id|>, so stop on it as well as on eos
+     terminators = [
+         tokenizer.eos_token_id,
+         tokenizer.convert_tokens_to_ids("<|eot_id|>")
+     ]
+
+     # Streams decoded tokens so the UI can render the reply incrementally
+     streamer = TextIteratorStreamer(
+         tokenizer,
+         skip_prompt=True,
+         skip_special_tokens=True
+     )


  with gr.Blocks() as demo:
+     # Step 1: load the model
+     init_model()
+
+     # Step 2: build the Gradio chatbot UI and add its controls
+     chatbot = gr.Chatbot(
+         height=900,
+         avatar_images=(user_avatar, bot_avatar)
+     )
+     msg = gr.Textbox()
+     clear = gr.ClearButton([msg, chatbot])
+
+     # Clear the stored conversation history
+     def clear_history():
+         global llama3_chat_history
+         llama3_chat_history = []
+
+     # Callback that produces the reply
+     def respond(message, chat_history):
+
+         global llama3_chat_history, tokenizer, model, streamer
+
+         llama3_chat_history.append({"role": "user", "content": message})
+
+         # Format the conversation history with Llama 3's built-in chat template
+         history_str = tokenizer.apply_chat_template(
+             llama3_chat_history,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+
+         # Tokenize
+         inputs = tokenizer(history_str, return_tensors='pt').to(device)
+
+         chat_history.append([message, ""])
+
+         generation_kwargs = dict(
+             **inputs,
+             streamer=streamer,
+             max_new_tokens=4096,
+             num_beams=1,
+             do_sample=True,
+             top_p=0.8,
+             temperature=0.3,
+             eos_token_id=terminators
+         )
+
+         # Run generation on a worker thread and watch the streamed output
+         thread = Thread(target=model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         for new_text in streamer:
+             chat_history[-1][1] += new_text
+             yield "", chat_history
+
+         llama3_chat_history.append(
+             {"role": "assistant", "content": chat_history[-1][1]}
+         )
+
+     # Clicking the clear button also wipes the stored history
+     clear.click(clear_history)
+     msg.submit(respond, [msg, chatbot], [msg, chatbot])
+
+ if __name__ == "__main__":
+     demo.queue(concurrency_count=1, max_size=1).launch(server_name="0.0.0.0", server_port=7860)
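
The respond callback depends on two Llama 3 conventions: the tokenizer's built-in chat template serializes the message history into a prompt string, and the <|eot_id|> end-of-turn token must be added to the stop list. A minimal sketch of that formatting step, assuming the same aaditya/Llama3-OpenBioLLM-8B tokenizer the commit loads (the example messages are purely illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("aaditya/Llama3-OpenBioLLM-8B")

history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is hypertension?"},
]

# Renders the history into Llama 3's header-tagged chat format and appends
# the assistant header, so generation continues as the assistant's turn.
prompt = tokenizer.apply_chat_template(
    history, tokenize=False, add_generation_prompt=True
)

# Llama 3 emits <|eot_id|> at the end of each turn, so generation must
# stop on it in addition to the regular eos token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]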
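The streaming itself follows the standard TextIteratorStreamer pattern: model.generate blocks until completion, so it runs on a worker thread that pushes decoded text into the streamer, while the caller consumes the streamer as an iterator. A standalone sketch of that pattern, using "gpt2" purely as a small stand-in model:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs on a worker thread while
# the main thread drains the streamer chunk by chunk.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=40))
thread.start()

for new_text in streamer:  # yields decoded text fragments as they arrive
    print(new_text, end="", flush=True)
thread.join()

Because respond is a generator that yields the partially built chat_history on every chunk, Gradio re-renders the chatbot on each yield, which is what makes the reply appear incrementally in the UI.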
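One version caveat: queue(concurrency_count=..., max_size=...) is the Gradio 3.x signature; Gradio 4.x removed concurrency_count from queue(). If this Space were run under Gradio 4.x, the closest equivalent would presumably be:

# Gradio 4.x sketch (an assumption; the committed code targets the 3.x API)
demo.queue(max_size=1, default_concurrency_limit=1).launch(server_name="0.0.0.0", server_port=7860)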