Update handler.py

handler.py · CHANGED · +35 -47
@@ -31,47 +31,43 @@ def topk(probs, n=9):
     return int(tokenId)
 
 def model_infer(model, tokenizer, review, max_length=300):
-    # Preprocess the init token (task designator)
-    review_encoded = tokenizer.encode(review)
-    result = review_encoded
-    initial_input = torch.tensor(review_encoded).unsqueeze(0).to(device)
-
-    with torch.set_grad_enabled(False):
-        # Feed the init token to the model
-        output = model(initial_input)
-
-        # Flatten the logits at the final time step
-        logits = output.logits[0,-1]
-
-        # Make a top-k choice and append to the result
-        #choices = [topk(logits) for i in range(5)]
-        choices = topk(logits)
-        result.append(choices)
-
-        # For max_length times:
-        for _ in range(max_length):
-            # Feed the current sequence to the model and make a choice
-            input = torch.tensor(result).unsqueeze(0).to(device)
-            output = model(input)
-            logits = output.logits[0,-1]
-            res_id = topk(logits)
-
-            # If the chosen token is EOS, return the result
-            if res_id == tokenizer.eos_token_id:
-                return tokenizer.decode(result)
-            else: # Append to the sequence
-                result.append(res_id)
-
-    # IF no EOS is generated, return after the max_len
-    return tokenizer.decode(result)
-
-def predict(text, model, tokenizer):
-    result_text = []
-    for i in range(6):
-        summary = model_infer(model, tokenizer, text+"TL;DR").strip()
-        result_text.append(summary[len(text)+5:])
+    result_text = []
+    for i in range(6):
+
+        # Preprocess the init token (task designator)
+        review_encoded = tokenizer.encode(review)
+        result = review_encoded
+        initial_input = torch.tensor(review_encoded).unsqueeze(0).to(device)
+
+        with torch.set_grad_enabled(False):
+            # Feed the init token to the model
+            output = model(initial_input)
+
+            # Flatten the logits at the final time step
+            logits = output.logits[0,-1]
+
+            # Make a top-k choice and append to the result
+            #choices = [topk(logits) for i in range(5)]
+            choices = topk(logits)
+            result.append(choices)
+
+            # For max_length times:
+            for _ in range(max_length):
+                # Feed the current sequence to the model and make a choice
+                input = torch.tensor(result).unsqueeze(0).to(device)
+                output = model(input)
+                logits = output.logits[0,-1]
+                res_id = topk(logits)
+
+                # If the chosen token is EOS, return the result
+                if res_id == tokenizer.eos_token_id:
+                    return tokenizer.decode(result)
+                else: # Append to the sequence
+                    result.append(res_id)
+
+        # IF no EOS is generated, return after the max_len
+        result_text.append(tokenizer.decode(result))
     return sorted(result_text, key=len)[3]
-    #print("summary:", sorted(result_text, key=len)[3])
 
 class EndpointHandler():
     def __init__(self, path=""):
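The topk helper is defined above this hunk, so the diff only exposes its signature (def topk(probs, n=9)) in the hunk header and its closing return int(tokenId) as context. As a reading aid, here is a minimal sketch of a top-k sampler with that shape; the body is an assumption, not the committed code:

import torch
import torch.nn.functional as F

def topk(probs, n=9):
    # Assumed implementation: keep the n largest logits, renormalize, sample one id
    values, indices = torch.topk(probs, n)  # top-n scores and their vocab ids
    dist = F.softmax(values, dim=-1)        # probabilities over just the top n
    choice = torch.multinomial(dist, 1)     # sample a position among the top n
    tokenId = indices[choice]               # map the position back to a vocab id
    return int(tokenId)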
@@ -80,17 +76,9 @@ class EndpointHandler():
         self.model = AutoModelForCausalLM.from_pretrained("Lin0He/text-summary-gpt2-short")
 
 
-    def __call__(self, data: Dict[str, Any]) ->
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         # process input
         inputs = data.pop("inputs", data)
         # process input text
-        prediction =
+        prediction = model_infer( self.model, self.tokenizer,inputs+"TL;DR")
         return prediction
-
-
-
-'''
-predictor = pipeline("summarization", model = model, tokenizer = tokenizer)
-result = predictor("Input text for prediction")
-print(result)
-'''
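Both hunks rely on names that never appear in the visible context: torch, device, Dict, Any, AutoModelForCausalLM, and self.tokenizer. A plausible top-of-file preamble consistent with those uses follows; every line of it is inferred rather than part of the commit:

from typing import Dict, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Assumed: model_infer() moves tensors to `device`, so handler.py presumably defines one
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed: __call__ reads self.tokenizer, so __init__ presumably loads it next to self.model,
# e.g. self.tokenizer = AutoTokenizer.from_pretrained("Lin0He/text-summary-gpt2-short")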
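As a quick local smoke test, the updated handler can be driven the same way an Inference Endpoint would call it. The input string below is a placeholder; note that model_infer returns a decoded string, so the prediction is a str even though __call__ is annotated as returning Dict[str, Any]:

handler = EndpointHandler()
summary = handler({"inputs": "Some long product review to summarize"})
print(summary)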