prlabs2023 committed on
Commit
2937f03
·
1 Parent(s): 0c810f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -28,19 +28,29 @@ from fastapi import Form
28
  class Query(BaseModel):
29
  text: str
30
 
31
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
- tokenizer = BertTokenizerFast.from_pretrained('mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization')
33
- model = EncoderDecoderModel.from_pretrained('mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization').to(device)
 
 
 
 
 
 
 
34
 
35
  def generate_summary(text):
 
 
 
36
  # cut off at BERT max length 512
37
- inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
38
- input_ids = inputs.input_ids.to(device)
39
- attention_mask = inputs.attention_mask.to(device)
40
 
41
- output = model.generate(input_ids, attention_mask=attention_mask)
42
 
43
- return tokenizer.decode(output[0], skip_special_tokens=True)
44
 
45
 
46
  from fastapi import FastAPI, Request, Depends, UploadFile, File
@@ -93,6 +103,4 @@ async def get_answer(q: Query ):
93
 
94
 
95
  return "hello"
96
-
97
-
98
-
 
28
class Query(BaseModel):
    """Request body schema for the API.

    NOTE(review): presumably this text is forwarded to generate_summary —
    confirm against the endpoint handler.
    """

    # Raw input text supplied by the caller.
    text: str
30
 
31
# Previous BERT-based summarizer (replaced by the pipeline below; kept for reference):
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# tokenizer = BertTokenizerFast.from_pretrained('mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization')
# model = EncoderDecoderModel.from_pretrained('mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization').to(device)

# Module-level Hugging Face summarization pipeline (loaded once at import time).
# device=0 selects the first CUDA GPU when available; -1 runs on CPU.
summarizer = pipeline(
    "summarization",
    "pszemraj/long-t5-tglobal-base-16384-book-summary",
    device=0 if torch.cuda.is_available() else -1,
)
40
+
41
 
42
def generate_summary(text):
    """Summarize *text* with the module-level ``summarizer`` pipeline.

    Parameters
    ----------
    text : str
        The input text to summarize.

    Returns
    -------
    str
        The generated summary ("summary_text" of the first pipeline result).
    """
    # Removed: dead commented-out BERT tokenizer/model code that sat after the
    # return statement, and a stale "cut off at BERT max length 512" comment
    # that no longer applies to the long-t5 pipeline in use.
    result = summarizer(text)
    return result[0]["summary_text"]
54
 
55
 
56
  from fastapi import FastAPI, Request, Depends, UploadFile, File
 
103
 
104
 
105
  return "hello"
106
+