Update pipeline.py
pipeline.py  CHANGED  (+5 -4)

@@ -30,7 +30,7 @@ def topk(probs, n=9):
     tokenId = topIx[choice][0]
     return int(tokenId)
 
-def model_infer(model, tokenizer, review, max_length=300):
+def model_infer(model, tokenizer, review, max_length=10):
     result_text = []
     for i in range(6):
 
@@ -61,13 +61,14 @@ def model_infer(model, tokenizer, review, max_length=300):
 
         # If the chosen token is EOS, return the result
        if res_id == tokenizer.eos_token_id:
-
+            result_text.append(tokenizer.decode(result)[len(review):])
+            break
        else: # Append to the sequence
            result.append(res_id)
 
     # IF no EOS is generated, return after the max_len
-    result_text.append(tokenizer.decode(result))
-    return sorted(result_text, key=len)[
+    #result_text.append(tokenizer.decode(result))
+    return sorted(result_text, key=len)[4]
 
 class PreTrainedPipeline():
     def __init__(self, path=""):
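The hunks above show only the changed lines; the sampling loop between them is elided. For orientation, here is a minimal sketch of how the updated model_infer could fit together with the topk helper, assuming a GPT-2-style model and tokenizer from the transformers library. The inner generation loop, the tensor handling, and the body of topk are assumptions for illustration; only the lines that appear in the diff are taken from the repository.

# Hypothetical reconstruction for context -- the sampling loop is elided in the diff above.
import numpy as np
import torch

def topk(probs, n=9):
    # Sample one token id from the n most probable tokens (assumed behaviour of the helper).
    probs = np.asarray(probs)
    topIx = np.argsort(-probs)[:n]                 # indices of the n largest probabilities
    topProbs = probs[topIx] / probs[topIx].sum()   # renormalise over the top-n
    choice = np.random.choice(n, 1, p=topProbs)
    tokenId = topIx[choice][0]
    return int(tokenId)

def model_infer(model, tokenizer, review, max_length=10):
    result_text = []
    for i in range(6):                             # draw six candidate continuations
        result = tokenizer.encode(review)          # start each candidate from the prompt tokens
        for _ in range(max_length):
            input_ids = torch.tensor([result])
            with torch.no_grad():
                logits = model(input_ids).logits[0, -1, :]
            probs = torch.softmax(logits, dim=-1).numpy()
            res_id = topk(probs)

            # If the chosen token is EOS, keep only the text generated after the prompt
            if res_id == tokenizer.eos_token_id:
                result_text.append(tokenizer.decode(result)[len(review):])
                break
            else:  # Append to the sequence
                result.append(res_id)

    # IF no EOS is generated, the candidate is dropped (the fallback append is commented out)
    return sorted(result_text, key=len)[4]

A possible call site, assuming a stock GPT-2 checkpoint (the repository's own model and tokenizer may differ):

from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
print(model_infer(model, tokenizer, "The movie was"))

Sorting the six candidates by length and indexing [4] returns the second-longest one. Note that with the fallback append commented out, a candidate that never emits EOS within max_length contributes nothing to result_text, so fewer than five entries can remain and the [4] index would then raise an IndexError.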