Lin0He commited on
Commit
8a77317
·
1 Parent(s): 984d7ed

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +5 -4
pipeline.py CHANGED
@@ -30,7 +30,7 @@ def topk(probs, n=9):
30
  tokenId = topIx[choice][0]
31
  return int(tokenId)
32
 
33
- def model_infer(model, tokenizer, review, max_length=300):
34
  result_text = []
35
  for i in range(6):
36
 
@@ -61,13 +61,14 @@ def model_infer(model, tokenizer, review, max_length=300):
61
 
62
  # If the chosen token is EOS, return the result
63
  if res_id == tokenizer.eos_token_id:
64
- return tokenizer.decode(result)
 
65
  else: # Append to the sequence
66
  result.append(res_id)
67
 
68
  # IF no EOS is generated, return after the max_len
69
- result_text.append(tokenizer.decode(result))
70
- return sorted(result_text, key=len)[3]
71
 
72
  class PreTrainedPipeline():
73
  def __init__(self, path=""):
 
30
  tokenId = topIx[choice][0]
31
  return int(tokenId)
32
 
33
+ def model_infer(model, tokenizer, review, max_length=10):
34
  result_text = []
35
  for i in range(6):
36
 
 
61
 
62
  # If the chosen token is EOS, return the result
63
  if res_id == tokenizer.eos_token_id:
64
+ result_text.append(tokenizer.decode(result)[len(review):])
65
+ break
66
  else: # Append to the sequence
67
  result.append(res_id)
68
 
69
  # IF no EOS is generated, return after the max_len
70
+ #result_text.append(tokenizer.decode(result))
71
+ return sorted(result_text, key=len)[4]
72
 
73
  class PreTrainedPipeline():
74
  def __init__(self, path=""):