MOSS550V committed on
Commit
797f396
·
1 Parent(s): ae3df97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -18
app.py CHANGED
@@ -1,22 +1,26 @@
1
- from transformers import AutoModel, AutoTokenizer
 
2
  import gradio as gr
3
  import mdtex2html
4
 
5
- tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
6
- model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
7
- model = model.quantize(4)
8
- model = model.half().cuda()
9
- model.transformer.prefix_encoder.float()
10
- model = model.eval()
 
 
 
 
 
 
 
 
11
 
12
- CHECKPOINT_PATH = "MOSS550V/divination"
13
 
14
- prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
15
- new_prefix_state_dict = {}
16
- for k, v in prefix_state_dict.items():
17
- if k.startswith("transformer.prefix_encoder."):
18
- new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
19
- model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
20
 
21
  """Override Chatbot.postprocess"""
22
 
@@ -86,7 +90,7 @@ def reset_state():
86
 
87
 
88
  with gr.Blocks() as demo:
89
- gr.HTML("""<h1 align="center">预测</h1>""")
90
 
91
  chatbot = gr.Chatbot()
92
  with gr.Row():
@@ -98,9 +102,9 @@ with gr.Blocks() as demo:
98
  submitBtn = gr.Button("Submit", variant="primary")
99
  with gr.Column(scale=1):
100
  emptyBtn = gr.Button("Clear History")
101
- max_length = gr.Slider(0, 4096, value=64, step=1.0, label="Maximum length", interactive=True)
102
  top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
103
- temperature = gr.Slider(0, 1, value=0.45, step=0.01, label="Temperature", interactive=True)
104
 
105
  history = gr.State([])
106
 
@@ -110,4 +114,55 @@ with gr.Blocks() as demo:
110
 
111
  emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
112
 
113
- demo.queue().launch(share=False, inbrowser=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
  import gradio as gr
4
  import mdtex2html
5
 
6
+ import torch
7
+ import transformers
8
+ from transformers import (
9
+ AutoConfig,
10
+ AutoModel,
11
+ AutoTokenizer,
12
+ AutoTokenizer,
13
+ DataCollatorForSeq2Seq,
14
+ HfArgumentParser,
15
+ Seq2SeqTrainingArguments,
16
+ set_seed,
17
+ )
18
+
19
+ from arguments import ModelArguments, DataTrainingArguments
20
 
 
21
 
22
+ model = None
23
+ tokenizer = None
 
 
 
 
24
 
25
  """Override Chatbot.postprocess"""
26
 
 
90
 
91
 
92
  with gr.Blocks() as demo:
93
+ gr.HTML("""<h1 align="center">ChatGLM</h1>""")
94
 
95
  chatbot = gr.Chatbot()
96
  with gr.Row():
 
102
  submitBtn = gr.Button("Submit", variant="primary")
103
  with gr.Column(scale=1):
104
  emptyBtn = gr.Button("Clear History")
105
+ max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
106
  top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
107
+ temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
108
 
109
  history = gr.State([])
110
 
 
114
 
115
  emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
116
 
117
+
118
+
119
def main():
    """Load the model and tokenizer, then launch the Gradio demo.

    Model options are parsed into a ``ModelArguments`` instance, either from
    the command line or, when the script's only argument is a path ending in
    ``.json``, from that JSON file. The resulting tokenizer and model are
    stored in the module-level ``tokenizer`` and ``model`` globals before
    ``demo`` is queued and launched.
    """
    global model, tokenizer

    parser = HfArgumentParser(ModelArguments)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single ``*.json`` argument is treated as a file holding the
        # model arguments instead of parsing individual CLI flags.
        model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        model_args = parser.parse_args_into_dataclasses()[0]

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True)

    # Forward the prefix-tuning settings to the model config before the
    # model is instantiated from it.
    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection

    # The base model is loaded the same way whether or not a prefix
    # checkpoint is applied afterwards, so do it once up front.
    model = AutoModel.from_pretrained(
        model_args.model_name_or_path, config=config, trust_remote_code=True)

    # NOTE(review): this looks like a Hub repo id, but below it is joined as
    # a filesystem path -- confirm a local directory of this name exists.
    ptuning_checkpoint = "MOSS550V/divination"

    if ptuning_checkpoint is not None:
        print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
        # torch.load unpickles the file, so only trusted checkpoints should
        # be used here. Loading onto the CPU first avoids depending on the
        # device the checkpoint was saved from; load_state_dict copies the
        # tensors into the model's own parameters.
        prefix_state_dict = torch.load(
            os.path.join(ptuning_checkpoint, "pytorch_model.bin"),
            map_location="cpu",
        )
        key_prefix = "transformer.prefix_encoder."
        # Keep only the prefix-encoder entries, with the module path stripped
        # so the keys match the prefix encoder's own state dict.
        new_prefix_state_dict = {
            key[len(key_prefix):]: value
            for key, value in prefix_state_dict.items()
            if key.startswith(key_prefix)
        }
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

    if model_args.quantization_bit is not None:
        print(f"Quantized to {model_args.quantization_bit} bit")
        model = model.quantize(model_args.quantization_bit)

    if model_args.pre_seq_len is not None:
        # P-tuning v2: run the model in half precision on the GPU while the
        # prefix encoder is kept in float32.
        model = model.half().cuda()
        model.transformer.prefix_encoder.float().cuda()

    model = model.eval()
    demo.queue().launch(share=False, inbrowser=True)


if __name__ == "__main__":
    main()