MOSS550V commited on
Commit
f2d4d03
·
1 Parent(s): 4ad5016

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -102,7 +102,7 @@ with gr.Blocks() as demo:
102
  submitBtn = gr.Button("Submit", variant="primary")
103
  with gr.Column(scale=1):
104
  emptyBtn = gr.Button("Clear History")
105
- max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
106
  top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
107
  temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
108
 
@@ -123,9 +123,9 @@ def main():
123
  ModelArguments))
124
 
125
  tokenizer = AutoTokenizer.from_pretrained(
126
- "THUDM/chatglm-6b", trust_remote_code=True)
127
  config = AutoConfig.from_pretrained(
128
- "THUDM/chatglm-6b", trust_remote_code=True)
129
 
130
  config.pre_seq_len = 128
131
  config.prefix_projection = False
@@ -134,15 +134,15 @@ def main():
134
 
135
  if ptuning_checkpoint is not None:
136
  print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
137
- model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
138
- prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"), map_location='cpu')
139
  new_prefix_state_dict = {}
140
  for k, v in prefix_state_dict.items():
141
  if k.startswith("transformer.prefix_encoder."):
142
  new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
143
  model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
144
  else:
145
- model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
146
 
147
  # model = model.quantize(4)
148
 
 
102
  submitBtn = gr.Button("Submit", variant="primary")
103
  with gr.Column(scale=1):
104
  emptyBtn = gr.Button("Clear History")
105
+ max_length = gr.Slider(0, 4096, value=64, step=1.0, label="Maximum length", interactive=True)
106
  top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
107
  temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
108
 
 
123
  ModelArguments))
124
 
125
  tokenizer = AutoTokenizer.from_pretrained(
126
+ "THUDM/chatglm-6b-int4", trust_remote_code=True)
127
  config = AutoConfig.from_pretrained(
128
+ "THUDM/chatglm-6b-int4", trust_remote_code=True)
129
 
130
  config.pre_seq_len = 128
131
  config.prefix_projection = False
 
134
 
135
  if ptuning_checkpoint is not None:
136
  print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
137
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)
138
+ prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
139
  new_prefix_state_dict = {}
140
  for k, v in prefix_state_dict.items():
141
  if k.startswith("transformer.prefix_encoder."):
142
  new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
143
  model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
144
  else:
145
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)
146
 
147
  # model = model.quantize(4)
148