Spaces:
Runtime error
Runtime error
Commit
·
2937f03
1
Parent(s):
0c810f3
Update app.py
Browse files
app.py
CHANGED
@@ -28,19 +28,29 @@ from fastapi import Form
|
|
28 |
class Query(BaseModel):
|
29 |
text: str
|
30 |
|
31 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
32 |
-
tokenizer = BertTokenizerFast.from_pretrained('mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization')
|
33 |
-
model = EncoderDecoderModel.from_pretrained('mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization').to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def generate_summary(text):
|
|
|
|
|
|
|
36 |
# cut off at BERT max length 512
|
37 |
-
inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
|
38 |
-
input_ids = inputs.input_ids.to(device)
|
39 |
-
attention_mask = inputs.attention_mask.to(device)
|
40 |
|
41 |
-
output = model.generate(input_ids, attention_mask=attention_mask)
|
42 |
|
43 |
-
return tokenizer.decode(output[0], skip_special_tokens=True)
|
44 |
|
45 |
|
46 |
from fastapi import FastAPI, Request, Depends, UploadFile, File
|
@@ -93,6 +103,4 @@ async def get_answer(q: Query ):
|
|
93 |
|
94 |
|
95 |
return "hello"
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
28 |
class Query(BaseModel):
|
29 |
text: str
|
30 |
|
31 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
32 |
+
# tokenizer = BertTokenizerFast.from_pretrained('mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization')
|
33 |
+
# model = EncoderDecoderModel.from_pretrained('mrm8488/bert-small2bert-small-finetuned-cnn_daily_mail-summarization').to(device)
|
34 |
+
|
35 |
+
summarizer = pipeline(
|
36 |
+
"summarization",
|
37 |
+
"pszemraj/long-t5-tglobal-base-16384-book-summary",
|
38 |
+
device=0 if torch.cuda.is_available() else -1,
|
39 |
+
)
|
40 |
+
|
41 |
|
42 |
def generate_summary(text):
|
43 |
+
|
44 |
+
result = summarizer(text)
|
45 |
+
return result[0]["summary_text"]
|
46 |
# cut off at BERT max length 512
|
47 |
+
# inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
|
48 |
+
# input_ids = inputs.input_ids.to(device)
|
49 |
+
# attention_mask = inputs.attention_mask.to(device)
|
50 |
|
51 |
+
# output = model.generate(input_ids, attention_mask=attention_mask)
|
52 |
|
53 |
+
# return tokenizer.decode(output[0], skip_special_tokens=True)
|
54 |
|
55 |
|
56 |
from fastapi import FastAPI, Request, Depends, UploadFile, File
|
|
|
103 |
|
104 |
|
105 |
return "hello"
|
106 |
+
|
|
|
|