MOSS550V committed
Commit 5c795b2 · 1 Parent(s): 94db94a

Update app.py

Files changed (1): app.py +3 -3
app.py CHANGED
@@ -134,7 +134,7 @@ def main():
 
     if ptuning_checkpoint is not None:
         print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
-        model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True).float()
+        model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)
         prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"), map_location='cpu')
         new_prefix_state_dict = {}
         for k, v in prefix_state_dict.items():
@@ -142,12 +142,12 @@ def main():
                 new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
         model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
     else:
-        model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True).float()
+        model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)
 
     model = model.quantize(4)
 
     # P-tuning v2
-    model = model.half()
+    # model = model.half()
     model.transformer.prefix_encoder.float()
 
     model = model.eval()
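
In effect, the pre-quantized THUDM/chatglm-6b-int4 checkpoint is now loaded in its native dtype instead of being cast to fp32 with .float(), and the post-quantize .half() cast is commented out, while the P-tuning v2 prefix encoder is still kept in fp32. The following is a minimal sketch of the resulting loading path, not the exact app.py; the checkpoint path and pre_seq_len value are hypothetical placeholders.

import os
import torch
from transformers import AutoConfig, AutoModel

# Hypothetical setup; app.py builds these values elsewhere.
ptuning_checkpoint = "output/checkpoint-3000"   # placeholder path
config = AutoConfig.from_pretrained(
    "THUDM/chatglm-6b-int4", trust_remote_code=True, pre_seq_len=128)  # pre_seq_len assumed

# Load the INT4 checkpoint in its native precision (no .float() cast).
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)

# Splice the trained prefix-encoder weights into the model.
prefix_state_dict = torch.load(
    os.path.join(ptuning_checkpoint, "pytorch_model.bin"), map_location="cpu")
new_prefix_state_dict = {
    k[len("transformer.prefix_encoder."):]: v
    for k, v in prefix_state_dict.items()
    if k.startswith("transformer.prefix_encoder.")
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

model = model.quantize(4)                  # 4-bit weights, as in app.py
model.transformer.prefix_encoder.float()   # prefix encoder stays in fp32 (no global .half())
model = model.eval()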