MOSS550V committed on
Commit
4d5ad0d
·
1 Parent(s): 64a842e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -16
app.py CHANGED
@@ -121,26 +121,20 @@ def main():
121
 
122
  parser = HfArgumentParser((
123
  ModelArguments))
124
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
125
- # If we pass only one argument to the script and it's the path to a json file,
126
- # let's parse it to get our arguments.
127
- model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
128
- else:
129
- model_args = parser.parse_args_into_dataclasses()[0]
130
 
131
  tokenizer = AutoTokenizer.from_pretrained(
132
  "THUDM/chatglm-6b-int4", trust_remote_code=True)
133
  config = AutoConfig.from_pretrained(
134
  "MOSS550V/divination", trust_remote_code=True)
135
 
136
- config.pre_seq_len = model_args.pre_seq_len
137
- config.prefix_projection = model_args.prefix_projection
138
 
139
  ptuning_checkpoint = "MOSS550V/divination"
140
 
141
  if ptuning_checkpoint is not None:
142
  print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
143
- model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
144
  prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
145
  new_prefix_state_dict = {}
146
  for k, v in prefix_state_dict.items():
@@ -150,14 +144,11 @@ def main():
150
  else:
151
  model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)
152
 
153
- if model_args.quantization_bit is not None:
154
- print(f"Quantized to {model_args.quantization_bit} bit")
155
- model = model.quantize(model_args.quantization_bit)
156
 
157
- if model_args.pre_seq_len is not None:
158
- # P-tuning v2
159
- model = model.half()
160
- model.transformer.prefix_encoder.float()
161
 
162
  model = model.eval()
163
  demo.queue().launch(share=False, inbrowser=True)
 
121
 
122
  parser = HfArgumentParser((
123
  ModelArguments))
 
 
 
 
 
 
124
 
125
  tokenizer = AutoTokenizer.from_pretrained(
126
  "THUDM/chatglm-6b-int4", trust_remote_code=True)
127
  config = AutoConfig.from_pretrained(
128
  "MOSS550V/divination", trust_remote_code=True)
129
 
130
+ config.pre_seq_len = 128
131
+ config.prefix_projection = False  # NOTE(review): lowercase `false` is not a Python name and raises NameError; the boolean constant is `False`
132
 
133
  ptuning_checkpoint = "MOSS550V/divination"
134
 
135
  if ptuning_checkpoint is not None:
136
  print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
137
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)
138
  prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
139
  new_prefix_state_dict = {}
140
  for k, v in prefix_state_dict.items():
 
144
  else:
145
  model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", config=config, trust_remote_code=True)
146
 
147
+ model = model.quantize(4)
 
 
148
 
149
+ # P-tuning v2
150
+ model = model.half()
151
+ model.transformer.prefix_encoder.float()
 
152
 
153
  model = model.eval()
154
  demo.queue().launch(share=False, inbrowser=True)