ClueAI committed on
Commit
5085b6b
·
1 Parent(s): 123c88a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -14,17 +14,17 @@ def preprocess(text):
14
  return text
15
 
16
  def postprocess(text):
17
- return text.replace("\\n", "\n").replace("\\t", "\t").replace('%20',' ')
18
 
19
- def answer(text, sample=True, top_p=1, temperature=0.7):
20
  '''sample:是否抽样。生成任务,可以设置为True;
21
  top_p:0-1之间,生成的内容越多样'''
22
  text = preprocess(text)
23
- encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=768, return_tensors="pt").to(device)
24
  if not sample:
25
- out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, num_beams=1, length_penalty=0.6)
26
  else:
27
- out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
28
  out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
29
  return postprocess(out_text[0])
30
 
 
14
  return text
15
 
16
  def postprocess(text):
17
+ return text.replace("\\n", "\n").replace("\\t", "\t").replace('%20',' ').replace(" ", " ")
18
 
19
+ def answer(text, sample=True, top_p=0.9, temperature=0.7):
20
  '''sample:是否抽样。生成任务,可以设置为True;
21
  top_p:0-1之间,生成的内容越多样'''
22
  text = preprocess(text)
23
+ encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=1024, return_tensors="pt").to(device)
24
  if not sample:
25
+ out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=1024, num_beams=1, length_penalty=0.6)
26
  else:
27
+ out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=1024, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
28
  out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
29
  return postprocess(out_text[0])
30