Lin0He commited on
Commit
279b6de
·
1 Parent(s): bdb2bc4

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +6 -4
handler.py CHANGED
@@ -11,7 +11,9 @@ tokenizer.push_to_hub(repo_name="text-summary-gpt2-short", repo_id="Lin0He/text-
11
  '''
12
  import torch
13
  from typing import Dict, List, Any
14
- from transformers import pipeline, AutoModel, AutoTokenizer
 
 
15
 
16
  def topk(probs, n=9):
17
  # The scores are initially softmaxed to convert to probabilities
@@ -65,8 +67,8 @@ def model_infer(model, tokenizer, review, max_length=300):
65
  def predict(text, model, tokenizer):
66
  result_text = []
67
  for i in range(6):
68
- summary = model_infer(model, tokenizer, input+"TL;DR").strip()
69
- result_text.append(summary[len(input)+5:])
70
  return sorted(result_text, key=len)[3]
71
  #print("summary:", sorted(result_text, key=len)[3])
72
 
@@ -74,7 +76,7 @@ class EndpointHandler():
74
  def __init__(self, path="Lin0He/text-summary-gpt2-short"):
75
  # load model and tokenizer from path
76
  self.tokenizer = AutoTokenizer.from_pretrained(path)
77
- self.model = AutoModel.from_pretrained(path)
78
 
79
 
80
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
11
  '''
12
  import torch
13
  from typing import Dict, List, Any
14
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
15
+
16
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
 
18
  def topk(probs, n=9):
19
  # The scores are initially softmaxed to convert to probabilities
 
67
  def predict(text, model, tokenizer):
68
  result_text = []
69
  for i in range(6):
70
+ summary = model_infer(model, tokenizer, text+"TL;DR").strip()
71
+ result_text.append(summary[len(text)+5:])
72
  return sorted(result_text, key=len)[3]
73
  #print("summary:", sorted(result_text, key=len)[3])
74
 
 
76
  def __init__(self, path="Lin0He/text-summary-gpt2-short"):
77
  # load model and tokenizer from path
78
  self.tokenizer = AutoTokenizer.from_pretrained(path)
79
+ self.model = AutoModelForCausalLM.from_pretrained(path)
80
 
81
 
82
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: