Text Generation
Transformers
PyTorch
skywork
custom_code
zhao1iang committed on
Commit
1094366
·
1 Parent(s): 304814c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +10 -14
README.md CHANGED
@@ -120,21 +120,17 @@ def special_encode(input, tokenizer):
120
 
121
  return res_id
122
 
123
- def special_encode(input, tokenizer):
124
- raw_str = "[USER]%s[SEP][BOT]" % input.strip().replace("\r", "")
125
- eos_id = tokenizer.eos_token_id
126
- bos_id = tokenizer.bos_token_id
127
- sep_id = tokenizer.encode("[SEP]")[-1]
128
- res_id = [eos_id, bos_id]
129
- arr = raw_str.split("[SEP]")
130
- for elem_idx in range(len(arr)):
131
- elem = arr[elem_idx]
132
- elem_id = tokenizer.encode(elem)[1:]
133
- res_id += elem_id
134
- if elem_idx < len(arr) - 1:
135
- res_id.append(sep_id)
136
 
137
- return res_id
138
 
139
  if __name__ == '__main__':
140
  text = "小王要将150千克含药量20%的农药稀释成含药量5%的药水.需要加水多少千克?"
 
120
 
121
  return res_id
122
 
123
+ def extract_res(response):
124
+ if "[BOT]" in response:
125
+ response = response.split("[BOT]")[1]
126
+ if "<s>" in response:
127
+ response = response.split("<s>")[-1]
128
+ if "</s>" in response:
129
+ response = response.split("</s>")[0]
130
+ if "[SEP]" in response:
131
+ response = response.split("[SEP]")[0]
132
+ return response
 
 
 
133
 
 
134
 
135
  if __name__ == '__main__':
136
  text = "小王要将150千克含药量20%的农药稀释成含药量5%的药水.需要加水多少千克?"