Lin0He committed
Commit dffcf5a · 1 Parent(s): 9d9d3b4

Update handler.py

Files changed (1):
  1. handler.py +35 -47
handler.py CHANGED
@@ -31,47 +31,43 @@ def topk(probs, n=9):
     return int(tokenId)
 
 def model_infer(model, tokenizer, review, max_length=300):
-    # Preprocess the init token (task designator)
-    review_encoded = tokenizer.encode(review)
-    result = review_encoded
-    initial_input = torch.tensor(review_encoded).unsqueeze(0).to(device)
-
-    with torch.set_grad_enabled(False):
-        # Feed the init token to the model
-        output = model(initial_input)
-
-        # Flatten the logits at the final time step
-        logits = output.logits[0,-1]
-
-        # Make a top-k choice and append to the result
-        #choices = [topk(logits) for i in range(5)]
-        choices = topk(logits)
-        result.append(choices)
-
-        # For max_length times:
-        for _ in range(max_length):
-            # Feed the current sequence to the model and make a choice
-            input = torch.tensor(result).unsqueeze(0).to(device)
-            output = model(input)
+    result_text = []
+    for i in range(6):
+
+        # Preprocess the init token (task designator)
+        review_encoded = tokenizer.encode(review)
+        result = review_encoded
+        initial_input = torch.tensor(review_encoded).unsqueeze(0).to(device)
+
+        with torch.set_grad_enabled(False):
+            # Feed the init token to the model
+            output = model(initial_input)
+
+            # Flatten the logits at the final time step
             logits = output.logits[0,-1]
-            res_id = topk(logits)
-
-            # If the chosen token is EOS, return the result
-            if res_id == tokenizer.eos_token_id:
-                return tokenizer.decode(result)
-            else: # Append to the sequence
-                result.append(res_id)
+
+            # Make a top-k choice and append to the result
+            #choices = [topk(logits) for i in range(5)]
+            choices = topk(logits)
+            result.append(choices)
+
+            # For max_length times:
+            for _ in range(max_length):
+                # Feed the current sequence to the model and make a choice
+                input = torch.tensor(result).unsqueeze(0).to(device)
+                output = model(input)
+                logits = output.logits[0,-1]
+                res_id = topk(logits)
+
+                # If the chosen token is EOS, return the result
+                if res_id == tokenizer.eos_token_id:
+                    return tokenizer.decode(result)
+                else: # Append to the sequence
+                    result.append(res_id)
 
     # IF no EOS is generated, return after the max_len
-    return tokenizer.decode(result)
-
-def predict(text, model, tokenizer):
-    result_text = []
-    for i in range(6):
-        summary = model_infer(model, tokenizer, text+"TL;DR").strip()
-        result_text.append(summary[len(text)+5:])
+        result_text.append(tokenizer.decode(result))
     return sorted(result_text, key=len)[3]
-    #print("summary:", sorted(result_text, key=len)[3])
 
 class EndpointHandler():
     def __init__(self, path=""):
@@ -80,17 +76,9 @@ class EndpointHandler():
         self.model = AutoModelForCausalLM.from_pretrained("Lin0He/text-summary-gpt2-short")
 
 
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         # process input
         inputs = data.pop("inputs", data)
         # process input text
-        prediction = predict(inputs, self.model, self.tokenizer)
+        prediction = model_infer( self.model, self.tokenizer,inputs+"TL;DR")
        return prediction
-
-
-
-    '''
-    predictor = pipeline("summarization", model = model, tokenizer = tokenizer)
-    result = predictor("Input text for prediction")
-    print(result)
-    '''
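
For reference, a minimal sketch of how the updated handler could be exercised locally. It assumes handler.py is importable from the working directory and follows the usual Hugging Face inference-endpoint convention of a payload dict with an "inputs" key; the sample review text is hypothetical.

# Hypothetical smoke test for the updated EndpointHandler -- a sketch, not part of the commit.
from handler import EndpointHandler

handler = EndpointHandler()  # __init__ pulls Lin0He/text-summary-gpt2-short from the Hub
payload = {"inputs": "The hotel was clean and the staff were friendly, but the rooms were cramped."}
summary = handler(payload)
print(summary)

After this change, model_infer samples up to six candidate continuations (returning early if an EOS token is drawn) and otherwise picks the middle-length candidate via sorted(result_text, key=len)[3], so handler(payload) returns a single string rather than the List[Dict[str, Any]] the old annotation suggested. Note also that the removed predict helper stripped the prompt prefix (summary[len(text)+5:]), whereas model_infer now decodes the full sequence, so the returned string still contains the input text and the "TL;DR" marker.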