---
license: openrail
datasets:
- shareAI/ShareGPT-Chinese-English-90k
- shareAI/CodeChat
language:
- zh
library_name: transformers
tags:
- code
---

Inference code for multi-turn conversation:
```python
# adapted from Firefly
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def main():
    model_name = 'shareAI/CodeLLaMA-chat-13b-Chinese'

    device = 'cuda'
    max_new_tokens = 500    # maximum number of new tokens to generate per turn
    history_max_len = 1000  # maximum token length of the history the model sees
    top_p = 0.9
    temperature = 0.35
    repetition_penalty = 1.0

    # device_map='auto' already places the model on the available device(s);
    # calling .to(device) on a dispatched model can raise an error in recent
    # transformers versions, so the model is left where accelerate put it.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_fast=False
    )

    # running token history of the whole conversation, shape (1, n_tokens)
    history_token_ids = torch.tensor([[]], dtype=torch.long)

    user_input = input('User:')
    while True:
        # append the user turn, terminated by an eos token
        input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
        eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
        user_input_ids = torch.concat([input_ids, eos_token_id], dim=1)
        history_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)
        # feed only the most recent history_max_len tokens to the model
        model_input_ids = history_token_ids[:, -history_max_len:].to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
                temperature=temperature, repetition_penalty=repetition_penalty,
                eos_token_id=tokenizer.eos_token_id
            )
        # drop the prompt, keep only the newly generated reply (including its eos)
        model_input_ids_len = model_input_ids.size(1)
        response_ids = outputs[:, model_input_ids_len:]
        history_token_ids = torch.concat((history_token_ids, response_ids.cpu()), dim=1)
        response = tokenizer.batch_decode(response_ids)
        print("Bot:" + response[0].strip().replace(tokenizer.eos_token, ""))
        user_input = input('User:')


if __name__ == '__main__':
    main()
```
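
For reference, the script above maintains the conversation Firefly-style: each user message and each model reply is stored as plain tokens terminated by the tokenizer's eos token, and only the most recent `history_max_len` tokens are fed back on every turn. Below is a minimal sketch of that token layout, assuming the same tokenizer; the `build_history` helper is hypothetical and only illustrates the format, it is not part of the model's API:

```python
# Hypothetical helper illustrating the token layout the script maintains:
#   user_1 </s> bot_1 </s> user_2 </s> bot_2 </s> ...
from transformers import AutoTokenizer
import torch


def build_history(tokenizer, turns):
    """Concatenate alternating user/bot turns, appending eos after each turn."""
    eos = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
    history = torch.empty((1, 0), dtype=torch.long)
    for text in turns:
        ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids
        history = torch.cat([history, ids, eos], dim=1)
    return history


tokenizer = AutoTokenizer.from_pretrained('shareAI/CodeLLaMA-chat-13b-Chinese', use_fast=False)
history = build_history(tokenizer, ['你好', '你好,有什么可以帮你?'])
print(tokenizer.batch_decode(history)[0])  # turns separated by </s>
```

Because the script only keeps the tail of this token stream (`history_token_ids[:, -history_max_len:]`), long conversations silently drop their earliest turns; raise `history_max_len` if the model appears to forget early context, at the cost of more GPU memory per generation step.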