MOSS550V commited on
Commit
ae3df97
·
1 Parent(s): e5f9abd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -2,13 +2,22 @@ from transformers import AutoModel, AutoTokenizer
2
  import gradio as gr
3
  import mdtex2html
4
 
5
- tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
6
- model = AutoModel.from_pretrained("MOSS550V/divination", trust_remote_code=True)
7
  model = model.quantize(4)
8
  model = model.half().cuda()
9
  model.transformer.prefix_encoder.float()
10
  model = model.eval()
11
 
 
 
 
 
 
 
 
 
 
12
  """Override Chatbot.postprocess"""
13
 
14
 
 
2
  import gradio as gr
3
  import mdtex2html
4
 
5
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
6
+ model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
7
  model = model.quantize(4)
8
  model = model.half().cuda()
9
  model.transformer.prefix_encoder.float()
10
  model = model.eval()
11
 
12
+ CHECKPOINT_PATH = "MOSS550V/divination"
13
+
14
+ prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
15
+ new_prefix_state_dict = {}
16
+ for k, v in prefix_state_dict.items():
17
+ if k.startswith("transformer.prefix_encoder."):
18
+ new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
19
+ model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
20
+
21
  """Override Chatbot.postprocess"""
22
 
23